diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6a3e68d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+**/.DS_Store
\ No newline at end of file
diff --git a/README.md b/README.md
index 847260c..79bde7c 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,88 @@
-## My Project
+# Official PyTorch Implementation of Adaptive Slot Attention: Object Discovery with Dynamic Slot Number
+[![ArXiv](https://img.shields.io/badge/ArXiv-2406.09196-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.09196) [![HomePage](https://img.shields.io/badge/HomePage-Visit-blue.svg?logo=homeadvisor&logoColor=f5f5f5)](https://kfan21.github.io/AdaSlot/) ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
+> [**Adaptive Slot Attention: Object Discovery with Dynamic Slot Number**](https://arxiv.org/abs/2406.09196)
+> [Ke Fan](https://kfan21.github.io/), [Zechen Bai](https://www.baizechen.site/), [Tianjun Xiao](http://tianjunxiao.com/), [Tong He](https://hetong007.github.io/), [Max Horn](https://expectationmax.github.io/), [Yanwei Fu†](http://yanweifu.github.io/), [Francesco Locatello](https://www.francescolocatello.com/), [Zheng Zhang](https://scholar.google.com/citations?hl=zh-CN&user=k0KiE4wAAAAJ)
-TODO: Fill this README out!
-Be sure to:
+This is the official implementation of the CVPR'24 paper [Adaptive Slot Attention: Object Discovery with Dynamic Slot Number](https://openaccess.thecvf.com/content/CVPR2024/html/Fan_Adaptive_Slot_Attention_Object_Discovery_with_Dynamic_Slot_Number_CVPR_2024_paper.html).
-* Change the title in this README
-* Edit your repository description on GitHub
+## Introduction
-## Security
+![framework](framework.png)
-See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
+Object-centric learning (OCL) extracts object representations with slots, offering flexibility and interpretability. Slot attention, a widely used OCL method, refines slot representations with an attention mechanism, but it requires a predefined number of slots and thus ignores the variable number of objects across instances. To address this, we propose a complexity-aware object auto-encoder framework with adaptive slot attention (AdaSlot), which dynamically determines the appropriate number of slots for each input through a discrete slot-sampling module, while a masked slot decoder suppresses the unselected slots during decoding. Extensive experiments show that the framework matches or outperforms fixed-slot models while adapting the slot number to instance complexity, opening opportunities for further research.
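+
+The slot-selection idea can be sketched in a few lines of PyTorch (an illustrative sketch only, not the code in this repository; the names `slots`, `score_mlp`, and `select_slots` are hypothetical):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def select_slots(slots: torch.Tensor, score_mlp, tau: float = 1.0):
+    """Keep or drop each slot with a straight-through Gumbel-Softmax decision.
+
+    slots:     (batch, num_slots, slot_dim) slot representations.
+    score_mlp: module mapping each slot to two logits (drop, keep).
+    """
+    logits = score_mlp(slots)                                      # (B, K, 2)
+    keep = F.gumbel_softmax(logits, tau=tau, hard=True)[..., 1:]   # (B, K, 1), hard 0/1
+    # Unselected slots are suppressed before decoding (masked slot decoder).
+    return slots * keep, keep
+```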
-## License
+## Development Setup
-This project is licensed under the Apache-2.0 License.
+AdaSlot requires Python 3.8 or later and is installed with [poetry](https://python-poetry.org/docs/#installation). After installing `poetry`, check out the repo and set up a development environment:
+
+```bash
+# install python3.8
+sudo apt update
+sudo apt install software-properties-common
+sudo add-apt-repository ppa:deadsnakes/ppa
+sudo apt install python3.8
+
+# install poetry with python3.8
+curl -sSL https://install.python-poetry.org | python3.8 - --version 1.2.0
+## make sure poetry is on your PATH
+
+# create virtual environment with poetry
+cd $code_path
+poetry install -E timm
+```
+
+This installs the `ocl` package and the CLI scripts used for running experiments in a poetry-managed virtual environment. Activate the virtual environment with `poetry shell` before running experiments.
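+
+As a quick optional sanity check, the `ocl` package should be importable from inside the environment (a minimal check, assuming the steps above completed without errors):
+
+```bash
+poetry shell
+python -c "import ocl; print(ocl.__file__)"
+```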
+
+## Running experiments
+
+Experiments are defined in the folder `configs/experiment` and are run by setting the `experiment` variable. Training an AdaSlot model starts from a fixed-slot (K-max) checkpoint, so each dataset follows the same three steps: train the fixed-slot model, train AdaSlot from that checkpoint, then evaluate. For example, for MOVi-E, MOVi-C, COCO, and CLEVR-10:
+
+```bash
+poetry shell
+
+python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml
+python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
+python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
+
+python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml
+python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
+python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
+
+python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml
+python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
+python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
+
+python -m ocl.cli.train +experiment=slot_attention/clevr10.yaml
+python -m ocl.cli.train +experiment=slot_attention/clevr10_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
+python -m ocl.cli.eval +experiment=slot_attention/clevr10_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
+```
+
+Results are saved in a timestamped subdirectory under `outputs/`, e.g. `outputs/projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml/<timestamp>` for the first command above. The prefix path `outputs` can be configured via the `experiment.root_output_folder` variable.
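+
+For example, to redirect all outputs to another location (a hedged example; the prefix is resolved via `${oc.select:experiment.root_output_folder,outputs}` in `configs/experiment/_output_path.yaml`, and the target directory below is a placeholder):
+
+```bash
+# the leading '+' adds the key in case it is not already part of the composed config
+python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml \
+  +experiment.root_output_folder=/data/adaslot_outputs
+```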
+
+## Citation
+
+Please cite our paper if you find this repo useful!
+
+```bibtex
+@inproceedings{fan2024adaptive,
+ title={Adaptive slot attention: Object discovery with dynamic slot number},
+ author={Fan, Ke and Bai, Zechen and Xiao, Tianjun and He, Tong and Horn, Max and Fu, Yanwei and Locatello, Francesco and Zhang, Zheng},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={23062--23071},
+ year={2024}
+}
+```
+
+This work builds upon the following related project:
+
+```bibtex
+@misc{oclf,
+ author = {Max Horn and Maximilian Seitzer and Andrii Zadaianchuk and Zixu Zhao and Dominik Zietlow and Florian Wenzel and Tianjun Xiao},
+ title = {Object Centric Learning Framework (version 0.1)},
+ year = {2023},
+ url = {https://github.com/amazon-science/object-centric-learning-framework},
+}
+```
diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100644
index 0000000..67795f7
--- /dev/null
+++ b/configs/__init__.py
@@ -0,0 +1 @@
+# Hydra needs this file to recognize the config folder when using hydra.main from console scripts
diff --git a/configs/dataset/clevr.yaml b/configs/dataset/clevr.yaml
new file mode 100644
index 0000000..1b851e9
--- /dev/null
+++ b/configs/dataset/clevr.yaml
@@ -0,0 +1,10 @@
+# Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets .
+defaults:
+ - webdataset
+
+train_shards: "/home/ubuntu/clevr_with_masks_new_splits/train/shard-{000000..000114}.tar"
+train_size: 70000
+val_shards: "/home/ubuntu/clevr_with_masks_new_splits/val/shard-{000000..000024}.tar"
+val_size: 15000
+test_shards: "/home/ubuntu/clevr_with_masks_new_splits/test/shard-{000000..000024}.tar"
+test_size: 15000
\ No newline at end of file
diff --git a/configs/dataset/clevr6.yaml b/configs/dataset/clevr6.yaml
new file mode 100644
index 0000000..880ce2b
--- /dev/null
+++ b/configs/dataset/clevr6.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+# Image dataset containing instances from CLEVR with at most 6 objects in each scene.
+defaults:
+ - /dataset/clevr@dataset
+ - /plugins/subset_dataset@plugins.01_clevr6_subset
+ - _self_
+
+dataset:
+ # Values derived from running `bin/compute_dataset_size.py`
+ train_size: 26240
+ val_size: 5553
+ test_size: 5600
+
+plugins:
+ 01_clevr6_subset:
+ predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}"
+ fields:
+ - visibility
diff --git a/configs/dataset/clevr6_old_splits.yaml b/configs/dataset/clevr6_old_splits.yaml
new file mode 100644
index 0000000..861a25c
--- /dev/null
+++ b/configs/dataset/clevr6_old_splits.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+# Image dataset containing instances from CLEVR with at most 6 objects in each scene.
+defaults:
+ - clevr_old_splits@dataset
+ - /plugins/subset_dataset@plugins.01_clevr6_subset
+ - _self_
+
+dataset:
+ # Values derived from running `bin/compute_dataset_size.py`
+ train_size: 29948
+ val_size: 3674
+ test_size: 3771
+
+plugins:
+ 01_clevr6_subset:
+ predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}"
+ fields:
+ - visibility
diff --git a/configs/dataset/clevr_old_splits.yaml b/configs/dataset/clevr_old_splits.yaml
new file mode 100644
index 0000000..5df7b76
--- /dev/null
+++ b/configs/dataset/clevr_old_splits.yaml
@@ -0,0 +1,10 @@
+# Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets .
+defaults:
+ - webdataset
+
+train_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/train/shard-{000000..000131}.tar"}
+train_size: 80000
+val_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/val/shard-{000000..000016}.tar"}
+val_size: 10000
+test_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/test/shard-{000000..000016}.tar"}
+test_size: 10000
diff --git a/configs/dataset/coco.yaml b/configs/dataset/coco.yaml
new file mode 100644
index 0000000..9740c89
--- /dev/null
+++ b/configs/dataset/coco.yaml
@@ -0,0 +1,11 @@
+# The coco2017 dataset with instance, stuff and caption annotations.
+defaults:
+ - webdataset
+
+train_shards: "/home/ubuntu/coco2017/train/shard-{000000..000412}.tar"
+train_size: 118287
+val_shards: "/home/ubuntu/coco2017/val/shard-{000000..000017}.tar"
+val_size: 5000
+test_shards: "/home/ubuntu/coco2017/test/shard-{000000..000126}.tar"
+test_size: 40670
+use_autopadding: true
diff --git a/configs/dataset/coco_nocrowd.yaml b/configs/dataset/coco_nocrowd.yaml
new file mode 100644
index 0000000..fb808df
--- /dev/null
+++ b/configs/dataset/coco_nocrowd.yaml
@@ -0,0 +1,12 @@
+# The coco2017 dataset with instance, stuff and caption annotations.
+# Validation dataset does not contain any crowd annotations.
+defaults:
+ - webdataset
+
+train_shards: ${dataset_prefix:"coco2017/train/shard-{000000..000412}.tar"}
+train_size: 118287
+val_shards: ${dataset_prefix:"coco2017/val_nocrowd/shard-{000000..000017}.tar"}
+val_size: 5000
+test_shards: ${dataset_prefix:"coco2017/test/shard-{000000..000126}.tar"}
+test_size: 40670
+use_autopadding: true
diff --git a/configs/dataset/movi_c.yaml b/configs/dataset/movi_c.yaml
new file mode 100644
index 0000000..f5b778d
--- /dev/null
+++ b/configs/dataset/movi_c.yaml
@@ -0,0 +1,31 @@
+# @package _global_
+defaults:
+ - webdataset@dataset
+ - _self_
+
+dataset:
+ train_shards: "/home/ubuntu/movi_c/train/shard-{000000..000298}.tar"
+ train_size: 9737
+ val_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar"
+ val_size: 250
+ test_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar"
+ test_size: 250
+ use_autopadding: true
+
+plugins:
+ 00_1_rename_fields:
+ _target_: ocl.plugins.RenameFields
+ train_mapping:
+ video: image
+ evaluation_mapping:
+ video: image
+ segmentations: mask
+ 00_2_adapt_mask_format:
+ _target_: ocl.plugins.SingleElementPreprocessing
+ training_transform: null
+ evaluation_transform:
+ _target_: ocl.preprocessing.IntegerToOneHotMask
+ output_axis: -4
+ max_instances: 10
+ ignore_typical_background: false
+ element_key: mask
diff --git a/configs/dataset/movi_c_image.yaml b/configs/dataset/movi_c_image.yaml
new file mode 100644
index 0000000..0add32c
--- /dev/null
+++ b/configs/dataset/movi_c_image.yaml
@@ -0,0 +1,25 @@
+# @package _global_
+# Image dataset containing subsampled frames from MOVI_C dataset.
+defaults:
+ - /dataset/movi_c
+ - /plugins/sample_frames_from_video@plugins.02_sample_frames
+ - _self_
+
+dataset:
+ # Values derived from running `bin/compute_dataset_size.py`.
+ train_size: 87633
+ val_size: 6000
+ test_size: 6000
+
+plugins:
+ 02_sample_frames:
+ n_frames_per_video: 9
+ n_eval_frames_per_video: -1
+ training_fields:
+ - image
+ evaluation_fields:
+ - image
+ - mask
+ dim: 0
+ seed: 457834752
+ shuffle_buffer_size: 1000
diff --git a/configs/dataset/movi_e.yaml b/configs/dataset/movi_e.yaml
new file mode 100644
index 0000000..3ef2e5a
--- /dev/null
+++ b/configs/dataset/movi_e.yaml
@@ -0,0 +1,30 @@
+# @package _global_
+defaults:
+ - webdataset@dataset
+ - _self_
+
+dataset:
+ train_shards: "/home/ubuntu/movi_e/train/shard-{000000..000679}.tar"
+ train_size: 9749
+ val_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar"
+ val_size: 250
+ test_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar"
+ test_size: 250
+ use_autopadding: true
+plugins:
+ 00_1_rename_fields:
+ _target_: ocl.plugins.RenameFields
+ train_mapping:
+ video: image
+ evaluation_mapping:
+ video: image
+ segmentations: mask
+ 00_2_adapt_mask_format:
+ _target_: ocl.plugins.SingleElementPreprocessing
+ training_transform: null
+ evaluation_transform:
+ _target_: ocl.preprocessing.IntegerToOneHotMask
+ output_axis: -4
+ max_instances: 23
+ ignore_typical_background: false
+ element_key: mask
diff --git a/configs/dataset/movi_e_image.yaml b/configs/dataset/movi_e_image.yaml
new file mode 100644
index 0000000..f90f65d
--- /dev/null
+++ b/configs/dataset/movi_e_image.yaml
@@ -0,0 +1,25 @@
+# @package _global_
+# Image dataset containing subsampled frames from MOVI_E dataset.
+defaults:
+ - /dataset/movi_e
+ - /plugins/sample_frames_from_video@plugins.02_sample_frames
+ - _self_
+
+dataset:
+ # Values derived from running `bin/compute_dataset_size.py`.
+ train_size: 87741
+ val_size: 6000
+ test_size: 6000
+
+plugins:
+ 02_sample_frames:
+ n_frames_per_video: 9
+ n_eval_frames_per_video: -1
+ training_fields:
+ - image
+ evaluation_fields:
+ - image
+ - mask
+ dim: 0
+ seed: 457834752
+ shuffle_buffer_size: 1000
diff --git a/configs/dataset/test.py b/configs/dataset/test.py
new file mode 100644
index 0000000..7d7547f
--- /dev/null
+++ b/configs/dataset/test.py
@@ -0,0 +1 @@
+lambda visibility, color, pixel_coords: True
\ No newline at end of file
diff --git a/configs/dataset/voc2012.yaml b/configs/dataset/voc2012.yaml
new file mode 100644
index 0000000..3cfefcd
--- /dev/null
+++ b/configs/dataset/voc2012.yaml
@@ -0,0 +1,11 @@
+# The PASCAL VOC 2012 dataset. Does not contain segmentation annotations.
+defaults:
+ - webdataset
+
+train_shards: ${dataset_prefix:"voc2012_detection/train/shard-{000000..000021}.tar"}
+train_size: 5717
+val_shards: ${dataset_prefix:"voc2012_detection/val/shard-{000000..000022}.tar"}
+val_size: 5823
+test_shards: ${dataset_prefix:"voc2012_detection/test/shard-{000000..000041}.tar"}
+test_size: 10991
+use_autopadding: true
diff --git a/configs/dataset/voc2012_trainaug.yaml b/configs/dataset/voc2012_trainaug.yaml
new file mode 100644
index 0000000..61b9d66
--- /dev/null
+++ b/configs/dataset/voc2012_trainaug.yaml
@@ -0,0 +1,11 @@
+# The PASCAL VOC 2012 dataset in the trainaug variant with instance segmentation masks.
+defaults:
+ - webdataset
+
+train_shards: "/home/ubuntu/voc2012/trainaug/shard-{000000..000040}.tar"
+train_size: 10582
+val_shards: "/home/ubuntu/voc2012/val/shard-{000000..000011}.tar"
+val_size: 1449
+test_shards: null
+test_size: null
+use_autopadding: true
diff --git a/configs/dataset/voc2012_trainval.yaml b/configs/dataset/voc2012_trainval.yaml
new file mode 100644
index 0000000..74eeebd
--- /dev/null
+++ b/configs/dataset/voc2012_trainval.yaml
@@ -0,0 +1,102 @@
+# The PASCAL VOC 2012 dataset, using joint train+val splits for training and validation.
+# This setting is often used in the unsupervised case.
+defaults:
+ - webdataset
+
+train_shards:
+ - ${dataset_prefix:"voc2012_detection/train/shard-000000.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000001.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000002.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000003.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000004.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000005.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000006.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000007.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000008.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000009.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000010.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000011.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000012.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000013.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000014.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000015.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000016.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000017.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000018.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000019.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000020.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000021.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000000.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000001.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000002.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000003.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000004.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000005.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000006.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000007.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000008.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000009.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000010.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000011.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000012.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000013.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000014.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000015.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000016.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000017.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000018.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000019.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000020.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000021.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000022.tar"}
+train_size: 11540
+val_shards:
+ - ${dataset_prefix:"voc2012_detection/train/shard-000000.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000001.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000002.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000003.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000004.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000005.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000006.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000007.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000008.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000009.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000010.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000011.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000012.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000013.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000014.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000015.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000016.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000017.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000018.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000019.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000020.tar"}
+ - ${dataset_prefix:"voc2012_detection/train/shard-000021.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000000.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000001.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000002.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000003.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000004.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000005.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000006.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000007.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000008.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000009.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000010.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000011.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000012.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000013.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000014.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000015.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000016.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000017.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000018.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000019.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000020.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000021.tar"}
+ - ${dataset_prefix:"voc2012_detection/val/shard-000022.tar"}
+val_size: 11540
+test_shards: ${dataset_prefix:"voc2012_detection/test/shard-{000000..000041}.tar"}
+test_size: 10991
+use_autopadding: true
diff --git a/configs/experiment/_output_path.yaml b/configs/experiment/_output_path.yaml
new file mode 100644
index 0000000..e405689
--- /dev/null
+++ b/configs/experiment/_output_path.yaml
@@ -0,0 +1,8 @@
+# @package hydra
+
+run:
+ dir: ${oc.select:experiment.root_output_folder,outputs}/${hydra:runtime.choices.experiment}/${now:%Y-%m-%d_%H-%M-%S}
+sweep:
+ dir: ${oc.select:experiment.root_output_folder,multirun}
+ subdir: ${hydra:runtime.choices.experiment}/${now:%Y-%m-%d_%H-%M-%S}
+output_subdir: config
diff --git a/configs/experiment/projects/bridging/dinosaur/_base_feature_recon.yaml b/configs/experiment/projects/bridging/dinosaur/_base_feature_recon.yaml
new file mode 100644
index 0000000..031321f
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_base_feature_recon.yaml
@@ -0,0 +1,87 @@
+# @package _global_
+# Default parameters for slot attention with a ViT decoder for feature reconstruction.
+defaults:
+ - /experiment/_output_path
+ - /training_config
+ - /feature_extractor/timm_model@models.feature_extractor
+ - /perceptual_grouping/slot_attention@models.perceptual_grouping
+ - /plugins/optimization@plugins.optimize_parameters
+ - /optimizers/adam@plugins.optimize_parameters.optimizer
+ - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
+ - _self_
+
+trainer:
+ gradient_clip_val: 1.0
+
+models:
+ feature_extractor:
+ model_name: vit_small_patch16_224_dino
+ pretrained: false
+ freeze: true
+ feature_level: 12
+
+ perceptual_grouping:
+ input_dim: 384
+ feature_dim: ${.object_dim}
+ object_dim: ${models.conditioning.object_dim}
+ use_projection_bias: false
+ positional_embedding:
+ _target_: ocl.neural_networks.wrappers.Sequential
+ _args_:
+ - _target_: ocl.neural_networks.positional_embedding.DummyPositionEmbed
+ - _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: ${....input_dim}
+ output_dim: ${....feature_dim}
+ hidden_dim: ${....input_dim}
+ initial_layer_norm: true
+ ff_mlp:
+ _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: ${..object_dim}
+ output_dim: ${..object_dim}
+ hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
+ initial_layer_norm: true
+ residual: true
+
+ object_decoder:
+ object_dim: ${models.perceptual_grouping.object_dim}
+ output_dim: ${models.perceptual_grouping.input_dim}
+ num_patches: 196
+ object_features_path: perceptual_grouping.objects
+ target_path: feature_extractor.features
+ image_path: input.image
+
+plugins:
+ optimize_parameters:
+ optimizer:
+ lr: 0.0004
+ lr_scheduler:
+ decay_rate: 0.5
+ decay_steps: 100000
+ warmup_steps: 10000
+
+losses:
+ mse:
+ _target_: ocl.losses.ReconstructionLoss
+ loss_type: mse
+ input_path: object_decoder.reconstruction
+ target_path: object_decoder.target # Object decoder does some resizing.
+
+visualizations:
+ input:
+ _target_: ocl.visualizations.Image
+ denormalization:
+ _target_: ocl.preprocessing.Denormalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ image_path: input.image
+ masks:
+ _target_: ocl.visualizations.Mask
+ mask_path: object_decoder.masks_as_image
+ pred_segmentation:
+ _target_: ocl.visualizations.Segmentation
+ denormalization:
+ _target_: ocl.preprocessing.Denormalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ image_path: input.image
+ mask_path: object_decoder.masks_as_image
diff --git a/configs/experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel.yaml b/configs/experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel.yaml
new file mode 100644
index 0000000..88e70bb
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel.yaml
@@ -0,0 +1,94 @@
+# @package _global_
+# Default parameters for slot attention with a ViT decoder for feature reconstruction.
+defaults:
+ - /experiment/_output_path
+ - /training_config
+ - /feature_extractor/timm_model@models.feature_extractor
+ - /perceptual_grouping/slot_attention_gumbel_v1@models.perceptual_grouping
+ - /plugins/optimization@plugins.optimize_parameters
+ - /optimizers/adam@plugins.optimize_parameters.optimizer
+ - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
+ - _self_
+
+trainer:
+ gradient_clip_val: 1.0
+
+models:
+ feature_extractor:
+ model_name: vit_small_patch16_224_dino
+ pretrained: false
+ freeze: true
+ feature_level: 12
+
+ perceptual_grouping:
+ input_dim: 384
+ feature_dim: ${.object_dim}
+ object_dim: ${models.conditioning.object_dim}
+ use_projection_bias: false
+ positional_embedding:
+ _target_: ocl.neural_networks.wrappers.Sequential
+ _args_:
+ - _target_: ocl.neural_networks.positional_embedding.DummyPositionEmbed
+ - _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: ${....input_dim}
+ output_dim: ${....feature_dim}
+ hidden_dim: ${....input_dim}
+ initial_layer_norm: true
+ ff_mlp:
+ _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: ${..object_dim}
+ output_dim: ${..object_dim}
+ hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
+ initial_layer_norm: true
+ residual: true
+ single_gumbel_score_network:
+ _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: ${..object_dim}
+ output_dim: 2
+ hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
+ initial_layer_norm: true
+ residual: false
+ object_decoder:
+ object_dim: ${models.perceptual_grouping.object_dim}
+ output_dim: ${models.perceptual_grouping.input_dim}
+ num_patches: 196
+ object_features_path: perceptual_grouping.objects
+ target_path: feature_extractor.features
+ image_path: input.image
+
+plugins:
+ optimize_parameters:
+ optimizer:
+ lr: 0.0004
+ lr_scheduler:
+ decay_rate: 0.5
+ decay_steps: 100000
+ warmup_steps: 10000
+
+losses:
+ mse:
+ _target_: ocl.losses.ReconstructionLoss
+ loss_type: mse
+ input_path: object_decoder.reconstruction
+ target_path: object_decoder.target # Object decoder does some resizing.
+
+
+visualizations:
+ input:
+ _target_: ocl.visualizations.Image
+ denormalization:
+ _target_: ocl.preprocessing.Denormalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ image_path: input.image
+ masks:
+ _target_: ocl.visualizations.Mask
+ mask_path: object_decoder.masks_as_image
+ pred_segmentation:
+ _target_: ocl.visualizations.Segmentation
+ denormalization:
+ _target_: ocl.preprocessing.Denormalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ image_path: input.image
+ mask_path: object_decoder.masks_as_image
diff --git a/configs/experiment/projects/bridging/dinosaur/_metrics_clevr_patch.yaml b/configs/experiment/projects/bridging/dinosaur/_metrics_clevr_patch.yaml
new file mode 100644
index 0000000..dd87521
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_metrics_clevr_patch.yaml
@@ -0,0 +1,13 @@
+# @package _global_
+defaults:
+ - /metrics/ari_metric@evaluation_metrics.ari
+ - /metrics/average_best_overlap_metric@evaluation_metrics.abo
+
+evaluation_metrics:
+ ari:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ abo:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ ignore_background: true
diff --git a/configs/experiment/projects/bridging/dinosaur/_metrics_coco.yaml b/configs/experiment/projects/bridging/dinosaur/_metrics_coco.yaml
new file mode 100644
index 0000000..eb5eca5
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_metrics_coco.yaml
@@ -0,0 +1,25 @@
+# @package _global_
+defaults:
+ - /metrics/ari_metric@evaluation_metrics.instance_mask_ari
+ - /metrics/unsupervised_mask_iou_metric@evaluation_metrics.instance_abo
+ - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
+
+evaluation_metrics:
+ instance_mask_ari:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ foreground: False
+ convert_target_one_hot: True
+ ignore_overlaps: True
+ instance_abo:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ use_threshold: False
+ matching: best_overlap
+ ignore_overlaps: True
+ instance_mask_corloc:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ use_threshold: False
+ ignore_overlaps: True
+
diff --git a/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop.yaml b/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop.yaml
new file mode 100644
index 0000000..7503630
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop.yaml
@@ -0,0 +1,69 @@
+# @package _global_
+defaults:
+ - /plugins/data_preprocessing@plugins.03a_preprocessing
+ - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
+
+plugins:
+ 03a_preprocessing:
+ evaluation_fields:
+ - image
+ - instance_mask
+ - instance_category
+ evaluation_transform:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
+ - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
+ # Drop instance_category again as some images do not contain it
+ - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
+ - _target_: ocl.preprocessing.AddEmptyMasks
+ mask_keys:
+ - instance_mask
+ - segmentation_mask
+
+ 03b_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ instance_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 224
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ segmentation_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 224
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
diff --git a/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_origres.yaml b/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_origres.yaml
new file mode 100644
index 0000000..3d7d63e
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_origres.yaml
@@ -0,0 +1,65 @@
+# @package _global_
+defaults:
+ - /plugins/data_preprocessing@plugins.03a_preprocessing
+ - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
+
+plugins:
+ 03a_preprocessing:
+ evaluation_fields:
+ - image
+ - instance_mask
+ - instance_category
+ evaluation_transform:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
+ - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
+ # Drop instance_category again as some images do not contain it
+ - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
+ - _target_: ocl.preprocessing.AddEmptyMasks
+ mask_keys:
+ - instance_mask
+ - segmentation_mask
+
+ 03b_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ _convert_: partial
+ size: [224, 224]
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ _convert_: partial
+ size: [224, 224]
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ instance_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ _convert_: partial
+ size: [224, 224]
+ segmentation_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ _convert_: partial
+ size: [224, 224]
diff --git a/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_randcrop.yaml b/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_randcrop.yaml
new file mode 100644
index 0000000..4dc7624
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_randcrop.yaml
@@ -0,0 +1,69 @@
+# @package _global_
+defaults:
+ - /plugins/data_preprocessing@plugins.03a_preprocessing
+ - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
+
+plugins:
+ 03a_preprocessing:
+ evaluation_fields:
+ - image
+ - instance_mask
+ - instance_category
+ evaluation_transform:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
+ - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
+ # Drop instance_category again as some images do not contain it
+ - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
+ - _target_: ocl.preprocessing.AddEmptyMasks
+ mask_keys:
+ - instance_mask
+ - segmentation_mask
+
+ 03b_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.RandomCrop
+ size: 224
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ instance_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 224
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ segmentation_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 224
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
diff --git a/configs/experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon.yaml b/configs/experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon.yaml
new file mode 100644
index 0000000..a40fa6b
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon.yaml
@@ -0,0 +1,37 @@
+# @package _global_
+defaults:
+ - /plugins/multi_element_preprocessing@plugins.03_preprocessing
+ - _self_
+
+plugins:
+ 03_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.MultiMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 128
diff --git a/configs/experiment/projects/bridging/dinosaur/_preprocessing_voc2012_segm_dino_feature_recon.yaml b/configs/experiment/projects/bridging/dinosaur/_preprocessing_voc2012_segm_dino_feature_recon.yaml
new file mode 100644
index 0000000..23d9b3c
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/_preprocessing_voc2012_segm_dino_feature_recon.yaml
@@ -0,0 +1,95 @@
+# @package _global_
+defaults:
+ - /plugins/multi_element_preprocessing@plugins.02a_format_consistency
+ - /plugins/data_preprocessing@plugins.02b_format_consistency
+ - /plugins/data_preprocessing@plugins.03a_preprocessing
+ - /plugins/multi_element_preprocessing@plugins.03b_preprocessing
+
+plugins:
+  # Make VOC2012 consistent with COCO.
+ 02a_format_consistency:
+ evaluation_transforms:
+ # Convert to one-hot encoding.
+ segmentation-instance:
+ _target_: ocl.preprocessing.IntegerToOneHotMask
+
+ 02b_format_consistency:
+ evaluation_fields:
+ - "segmentation-instance"
+ - "segmentation-class"
+ - "image"
+ evaluation_transform:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ # Create segmentation mask.
+ - _target_: ocl.preprocessing.VOCInstanceMasksToDenseMasks
+ instance_mask_key: segmentation-instance
+ class_mask_key: segmentation-class
+ classes_key: instance_category
+ - _target_: ocl.preprocessing.RenameFields
+ mapping:
+ segmentation-instance: instance_mask
+ 03a_preprocessing:
+ evaluation_fields:
+ - image
+ - instance_mask
+ - instance_category
+ evaluation_transform:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ # This is not needed for VOC.
+ # - _target_: ocl.preprocessing.InstanceMasksToDenseMasks
+ - _target_: ocl.preprocessing.AddSegmentationMaskFromInstanceMask
+ # Drop instance_category again as some images do not contain it
+ - "${lambda_fn:'lambda data: {k: v for k, v in data.items() if k != \"instance_category\"}'}"
+ - _target_: ocl.preprocessing.AddEmptyMasks
+ mask_keys:
+ - instance_mask
+ - segmentation_mask
+
+ 03b_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.RandomCrop
+ size: 224
+ - _target_: torchvision.transforms.RandomHorizontalFlip
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 224
+ interpolation: "${torchvision_interpolation_mode:BICUBIC}"
+ - "${lambda_fn:'lambda image: image.clamp(0.0, 1.0)'}" # Bicubic interpolation can get out of range
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ instance_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 224
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
+ segmentation_mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.DenseMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 224
+ - _target_: torchvision.transforms.CenterCrop
+ size: 224
diff --git a/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml b/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml
new file mode 100644
index 0000000..d15ccfd
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: coco
+ - /experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop
+ - /experiment/projects/bridging/dinosaur/_metrics_coco
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 200000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ conditioning:
+ n_slots: 7
+ object_dim: 256
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+ freeze: true
+
+ perceptual_grouping:
+ input_dim: 768
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoder
+ decoder:
+ features: [2048, 2048, 2048]
diff --git a/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml b/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml
new file mode 100644
index 0000000..f3b1ef8
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml
@@ -0,0 +1,65 @@
+# @package _global_
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: coco
+ - /experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop
+ - /experiment/projects/bridging/dinosaur/_metrics_coco
+ - /metrics/tensor_statistic@training_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@training_metrics.slots_keep_prob
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
+ conditioning:
+ n_slots: 33
+ object_dim: 256
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+ freeze: true
+
+ perceptual_grouping:
+ input_dim: 768
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoderGumbelV1
+ decoder:
+ features: [2048, 2048, 2048]
+ left_mask_path: None
+ mask_type: mask_normalized
+
+losses:
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 0.1
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+# outputs["hard_keep_decision"] = perceptual_grouping_output["hard_keep_decision"]
+# outputs["slots_keep_prob"]
+training_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml/2023-05-02_16-57-16/lightning_logs/version_0/checkpoints/epoch=95-step=177408.ckpt
\ No newline at end of file
diff --git a/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml b/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml
new file mode 100644
index 0000000..a949181
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml
@@ -0,0 +1,116 @@
+# @package _global_
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: coco
+ - /experiment/projects/bridging/dinosaur/_preprocessing_coco_dino_feature_recon_ccrop
+ - /experiment/projects/bridging/dinosaur/_metrics_coco
+ - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
+ - /metrics/ami_metric@evaluation_metrics.ami
+ - /metrics/nmi_metric@evaluation_metrics.nmi
+ - /metrics/purity_metric@evaluation_metrics.purity
+ - /metrics/precision_metric@evaluation_metrics.precision
+ - /metrics/recall_metric@evaluation_metrics.recall
+ - /metrics/f1_metric@evaluation_metrics.f1
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
+ conditioning:
+ n_slots: 33
+ object_dim: 256
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+ freeze: true
+
+ perceptual_grouping:
+ input_dim: 768
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoderGumbelV1
+ decoder:
+ features: [2048, 2048, 2048]
+ left_mask_path: None
+ mask_type: mask_normalized
+
+losses:
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 0.1
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+evaluation_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+ ami:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ foreground: true
+ convert_target_one_hot: true
+ ignore_overlaps: true
+ back_as_class: false
+
+ nmi:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ foreground: true
+ convert_target_one_hot: true
+ ignore_overlaps: true
+ back_as_class: false
+
+ purity:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ foreground: true
+ convert_target_one_hot: true
+ ignore_overlaps: true
+ back_as_class: false
+
+ precision:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ foreground: true
+ convert_target_one_hot: true
+ ignore_overlaps: true
+ back_as_class: false
+
+ recall:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ foreground: true
+ convert_target_one_hot: true
+ ignore_overlaps: true
+ back_as_class: false
+
+ f1:
+ prediction_path: object_decoder.masks_as_image
+ target_path: input.instance_mask
+ foreground: true
+ convert_target_one_hot: true
+ ignore_overlaps: true
+ back_as_class: false
+
diff --git a/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml b/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml
new file mode 100644
index 0000000..553241d
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml
@@ -0,0 +1,46 @@
+# @package _global_
+# ViT feature reconstruction on MOVI-C.
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: movi_c_image
+ - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
+ - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ conditioning:
+ n_slots: 11
+ object_dim: 128
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+
+ perceptual_grouping:
+ input_dim: 768
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoder
+ num_patches: 196
+ decoder:
+ features: [1024, 1024, 1024]
+
+ masks_as_image:
+ _target_: ocl.utils.resizing.Resize
+ input_path: object_decoder.masks
+ size: 128
+ resize_mode: bilinear
+ patch_mode: true
diff --git a/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml b/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml
new file mode 100644
index 0000000..10cd1d5
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml
@@ -0,0 +1,73 @@
+# @package _global_
+# ViT feature reconstruction on MOVI-C.
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: movi_c_image
+ - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
+ - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
+ - /metrics/tensor_statistic@training_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@training_metrics.slots_keep_prob
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
+ conditioning:
+ n_slots: 11
+ object_dim: 128
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+
+ perceptual_grouping:
+ input_dim: 768
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoderGumbelV1
+ num_patches: 196
+ decoder:
+ features: [1024, 1024, 1024]
+ left_mask_path: None
+ mask_type: mask_normalized
+
+ masks_as_image:
+ _target_: ocl.utils.resizing.Resize
+ input_path: object_decoder.masks
+ size: 128
+ resize_mode: bilinear
+ patch_mode: true
+
+losses:
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 0.1
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+# outputs["hard_keep_decision"] = perceptual_grouping_output["hard_keep_decision"]
+# outputs["slots_keep_prob"]
+training_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml/2023-04-28_15-20-10/lightning_logs/version_0/checkpoints/epoch=328-step=450401.ckpt
\ No newline at end of file
diff --git a/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml b/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml
new file mode 100644
index 0000000..57243bc
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml
@@ -0,0 +1,129 @@
+# @package _global_
+# ViT feature reconstruction on MOVI-C.
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: movi_c_image
+ - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
+ - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
+ - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
+ - /metrics/ami_metric@evaluation_metrics.ami
+ - /metrics/nmi_metric@evaluation_metrics.nmi
+ - /metrics/purity_metric@evaluation_metrics.purity
+ - /metrics/precision_metric@evaluation_metrics.precision
+ - /metrics/recall_metric@evaluation_metrics.recall
+ - /metrics/f1_metric@evaluation_metrics.f1
+ - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
+ conditioning:
+ n_slots: 11
+ object_dim: 128
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+
+ perceptual_grouping:
+ input_dim: 768
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoderGumbelV1
+ num_patches: 196
+ decoder:
+ features: [1024, 1024, 1024]
+ left_mask_path: None
+ mask_type: mask_normalized
+
+ masks_as_image:
+ _target_: ocl.utils.resizing.Resize
+ input_path: object_decoder.masks
+ size: 128
+ resize_mode: bilinear
+ patch_mode: true
+
+losses:
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 0.1
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+evaluation_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+ ami:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ nmi:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ purity:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ precision:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ recall:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ f1:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ instance_mask_corloc:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ use_threshold: False
+ ignore_background: True
+ ignore_overlaps: False
+
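+# MOVi is a video dataset; each video contributes 24 sampled frames as individual image samples.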
+plugins:
+ 02_sample_frames:
+ n_frames_per_video: 24
diff --git a/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml b/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml
new file mode 100644
index 0000000..8371a30
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml
@@ -0,0 +1,46 @@
+# @package _global_
+# ViT feature reconstruction on MOVI-E.
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: movi_e_image
+ - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
+ - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 100000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ conditioning:
+ n_slots: 24
+ object_dim: 128
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+
+ perceptual_grouping:
+ input_dim: 768
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoder
+ num_patches: 196
+ decoder:
+ features: [1024, 1024, 1024]
+
+ masks_as_image:
+ _target_: ocl.utils.resizing.Resize
+ input_path: object_decoder.masks
+ size: 128
+ resize_mode: bilinear
+ patch_mode: true
diff --git a/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml b/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml
new file mode 100644
index 0000000..a46ea8a
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml
@@ -0,0 +1,73 @@
+# @package _global_
+# AdaSlot training: ViT feature reconstruction on MOVI-E.
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: movi_e_image
+ - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
+ - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
+ - /metrics/tensor_statistic@training_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@training_metrics.slots_keep_prob
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
+ conditioning:
+ n_slots: 24
+ object_dim: 128
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+
+ perceptual_grouping:
+ input_dim: 768
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoderGumbelV1
+ num_patches: 196
+ decoder:
+ features: [1024, 1024, 1024]
+ left_mask_path: None
+ mask_type: mask_normalized
+
+ masks_as_image:
+ _target_: ocl.utils.resizing.Resize
+ input_path: object_decoder.masks
+ size: 128
+ resize_mode: bilinear
+ patch_mode: true
+
+losses:
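+  # AdaSlot sparsity regularizer: penalizes the hard slot keep decisions so the model learns to
+  # drop slots it does not need (here a purely linear penalty with weight 0.1).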
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 0.1
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+# outputs["hard_keep_decision"] = perceptual_grouping_output["hard_keep_decision"]
+# outputs["slots_keep_prob"]
+training_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+# load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml/2023-04-25_17-13-38/lightning_logs/version_0/checkpoints/epoch=86-step=119190.ckpt
\ No newline at end of file
diff --git a/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml b/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml
new file mode 100644
index 0000000..cd2e001
--- /dev/null
+++ b/configs/experiment/projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml
@@ -0,0 +1,129 @@
+# @package _global_
+# AdaSlot evaluation: ViT feature reconstruction on MOVI-E.
+defaults:
+ - /conditioning/random@models.conditioning
+ - /experiment/projects/bridging/dinosaur/_base_feature_recon_gumbel
+ - /neural_networks/mlp@models.object_decoder.decoder
+ - /dataset: movi_e_image
+ - /experiment/projects/bridging/dinosaur/_preprocessing_movi_dino_feature_recon
+ - /experiment/projects/bridging/dinosaur/_metrics_clevr_patch
+ - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
+ - /metrics/ami_metric@evaluation_metrics.ami
+ - /metrics/nmi_metric@evaluation_metrics.nmi
+ - /metrics/purity_metric@evaluation_metrics.purity
+ - /metrics/precision_metric@evaluation_metrics.precision
+ - /metrics/recall_metric@evaluation_metrics.recall
+ - /metrics/f1_metric@evaluation_metrics.f1
+ - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot.GroupingImgGumbel
+ conditioning:
+ n_slots: 24
+ object_dim: 128
+
+ feature_extractor:
+ model_name: vit_base_patch16_224_dino
+ pretrained: true
+
+ perceptual_grouping:
+ input_dim: 768
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.PatchDecoderGumbelV1
+ num_patches: 196
+ decoder:
+ features: [1024, 1024, 1024]
+ left_mask_path: None
+ mask_type: mask_normalized
+
+ masks_as_image:
+ _target_: ocl.utils.resizing.Resize
+ input_path: object_decoder.masks
+ size: 128
+ resize_mode: bilinear
+ patch_mode: true
+
+losses:
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 0.1
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+evaluation_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+ ami:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ nmi:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ purity:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ precision:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ recall:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ f1:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ instance_mask_corloc:
+ prediction_path: masks_as_image
+ target_path: input.mask
+ use_threshold: False
+ ignore_background: True
+ ignore_overlaps: False
+
+plugins:
+ 02_sample_frames:
+ n_frames_per_video: 24
diff --git a/configs/experiment/slot_attention/_base.yaml b/configs/experiment/slot_attention/_base.yaml
new file mode 100644
index 0000000..163e758
--- /dev/null
+++ b/configs/experiment/slot_attention/_base.yaml
@@ -0,0 +1,88 @@
+# @package _global_
+# Default parameters for slot attention.
+defaults:
+ - /experiment/_output_path
+ - /training_config
+ - /feature_extractor/slot_attention@models.feature_extractor
+ - /conditioning/random@models.conditioning
+ - /perceptual_grouping/slot_attention@models.perceptual_grouping
+ - /plugins/optimization@plugins.optimize_parameters
+ - /optimizers/adam@plugins.optimize_parameters.optimizer
+ - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
+ - _self_
+
+models:
+ conditioning:
+ object_dim: 64
+
+ perceptual_grouping:
+ feature_dim: 64
+ object_dim: ${..conditioning.object_dim}
+ kvq_dim: 128
+ positional_embedding:
+ _target_: ocl.neural_networks.wrappers.Sequential
+ _args_:
+ - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
+ n_spatial_dims: 2
+ feature_dim: 64
+ - _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: 64
+ output_dim: 64
+ hidden_dim: 128
+ initial_layer_norm: true
+ residual: false
+ ff_mlp:
+ _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: 64
+ output_dim: 64
+ hidden_dim: 128
+ initial_layer_norm: true
+ residual: true
+
+ object_decoder:
+ _target_: ocl.decoding.SlotAttentionDecoder
+ object_features_path: perceptual_grouping.objects
+ decoder:
+ _target_: ocl.decoding.get_slotattention_decoder_backbone
+ object_dim: ${models.perceptual_grouping.object_dim}
+ positional_embedding:
+ _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
+ n_spatial_dims: 2
+ feature_dim: ${models.perceptual_grouping.object_dim}
+ cnn_channel_order: true
+
+plugins:
+ optimize_parameters:
+ optimizer:
+ lr: 0.0004
+ lr_scheduler:
+ decay_rate: 0.5
+ decay_steps: 100000
+ warmup_steps: 10000
+
+losses:
+ mse:
+ _target_: ocl.losses.ReconstructionLoss
+ loss_type: mse_sum
+ input_path: object_decoder.reconstruction
+ target_path: input.image
+
+visualizations:
+ input:
+ _target_: ocl.visualizations.Image
+ denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
+ image_path: input.image
+ reconstruction:
+ _target_: ocl.visualizations.Image
+ denormalization: ${..input.denormalization}
+ image_path: object_decoder.reconstruction
+ objects:
+ _target_: ocl.visualizations.VisualObject
+ denormalization: ${..input.denormalization}
+ object_path: object_decoder.object_reconstructions
+ mask_path: object_decoder.masks
+ pred_segmentation:
+ _target_: ocl.visualizations.Segmentation
+ denormalization: ${..input.denormalization}
+ image_path: input.image
+ mask_path: object_decoder.masks
diff --git a/configs/experiment/slot_attention/_base_gumbel.yaml b/configs/experiment/slot_attention/_base_gumbel.yaml
new file mode 100644
index 0000000..2c9c6e0
--- /dev/null
+++ b/configs/experiment/slot_attention/_base_gumbel.yaml
@@ -0,0 +1,96 @@
+# @package _global_
+# Default parameters for slot attention with adaptive (Gumbel-based) slot selection.
+defaults:
+ - /experiment/_output_path
+ - /training_config
+ - /feature_extractor/slot_attention@models.feature_extractor
+ - /conditioning/random@models.conditioning
+ - /perceptual_grouping/slot_attention_gumbel_v1@models.perceptual_grouping
+ - /plugins/optimization@plugins.optimize_parameters
+ - /optimizers/adam@plugins.optimize_parameters.optimizer
+ - /lr_schedulers/exponential_decay@plugins.optimize_parameters.lr_scheduler
+ - _self_
+
+models:
+ conditioning:
+ object_dim: 64
+
+ perceptual_grouping:
+ feature_dim: 64
+ object_dim: ${..conditioning.object_dim}
+ kvq_dim: 128
+ positional_embedding:
+ _target_: ocl.neural_networks.wrappers.Sequential
+ _args_:
+ - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
+ n_spatial_dims: 2
+ feature_dim: 64
+ - _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: 64
+ output_dim: 64
+ hidden_dim: 128
+ initial_layer_norm: true
+ residual: false
+ ff_mlp:
+ _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: 64
+ output_dim: 64
+ hidden_dim: 128
+ initial_layer_norm: true
+ residual: true
+
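+  # Scoring head for adaptive slot selection: a small MLP that maps each slot to two logits
+  # (keep / drop), which are sampled with Gumbel-Softmax.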
+ single_gumbel_score_network:
+ _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: ${..object_dim}
+ output_dim: 2
+ hidden_dim: "${eval_lambda:'lambda dim: 4 * dim', ${..object_dim}}"
+ initial_layer_norm: true
+ residual: false
+
+ object_decoder:
+ _target_: ocl.decoding.SlotAttentionDecoder
+ object_features_path: perceptual_grouping.objects
+ decoder:
+ _target_: ocl.decoding.get_slotattention_decoder_backbone
+ object_dim: ${models.perceptual_grouping.object_dim}
+ positional_embedding:
+ _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
+ n_spatial_dims: 2
+ feature_dim: ${models.perceptual_grouping.object_dim}
+ cnn_channel_order: true
+
+plugins:
+ optimize_parameters:
+ optimizer:
+ lr: 0.0004
+ lr_scheduler:
+ decay_rate: 0.5
+ decay_steps: 100000
+ warmup_steps: 10000
+
+losses:
+ mse:
+ _target_: ocl.losses.ReconstructionLoss
+ loss_type: mse_sum
+ input_path: object_decoder.reconstruction
+ target_path: input.image
+
+visualizations:
+ input:
+ _target_: ocl.visualizations.Image
+ denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
+ image_path: input.image
+ reconstruction:
+ _target_: ocl.visualizations.Image
+ denormalization: ${..input.denormalization}
+ image_path: object_decoder.reconstruction
+ objects:
+ _target_: ocl.visualizations.VisualObject
+ denormalization: ${..input.denormalization}
+ object_path: object_decoder.object_reconstructions
+ mask_path: object_decoder.masks
+ pred_segmentation:
+ _target_: ocl.visualizations.Segmentation
+ denormalization: ${..input.denormalization}
+ image_path: input.image
+ mask_path: object_decoder.masks
diff --git a/configs/experiment/slot_attention/_base_large.yaml b/configs/experiment/slot_attention/_base_large.yaml
new file mode 100644
index 0000000..03014c9
--- /dev/null
+++ b/configs/experiment/slot_attention/_base_large.yaml
@@ -0,0 +1,93 @@
+# @package _global_
+# Default parameters for slot attention on resolution 128x128 with a ResNet encoder
+defaults:
+ - /experiment/_output_path
+ - /training_config
+ - /feature_extractor/timm_model@models.feature_extractor
+ - /perceptual_grouping/slot_attention@models.perceptual_grouping
+ - /plugins/optimization@plugins.optimize_parameters
+ - /optimizers/adam@plugins.optimize_parameters.optimizer
+ - /lr_schedulers/cosine_annealing@plugins.optimize_parameters.lr_scheduler
+ - _self_
+
+models:
+ feature_extractor:
+ model_name: resnet34_savi
+ feature_level: 4
+ pretrained: false
+ freeze: false
+
+ perceptual_grouping:
+ feature_dim: ${models.perceptual_grouping.object_dim}
+ object_dim: ${models.conditioning.object_dim}
+ kvq_dim: ${models.perceptual_grouping.object_dim}
+ positional_embedding:
+ _target_: ocl.neural_networks.wrappers.Sequential
+ _args_:
+ - _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
+ n_spatial_dims: 2
+ feature_dim: 512
+ savi_style: true
+ - _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: 512
+ output_dim: ${models.perceptual_grouping.object_dim}
+ hidden_dim: ${models.perceptual_grouping.object_dim}
+ initial_layer_norm: true
+ ff_mlp:
+ _target_: ocl.neural_networks.build_two_layer_mlp
+ input_dim: ${models.perceptual_grouping.object_dim}
+ output_dim: ${models.perceptual_grouping.object_dim}
+ hidden_dim: "${eval_lambda:'lambda dim: 2 * dim', ${.input_dim}}"
+ initial_layer_norm: true
+ residual: true
+
+ object_decoder:
+ _target_: ocl.decoding.SlotAttentionDecoder
+ final_activation: tanh
+ decoder:
+ _target_: ocl.decoding.get_savi_decoder_backbone
+ object_dim: ${models.perceptual_grouping.object_dim}
+ larger_input_arch: true
+ channel_multiplier: 1
+ positional_embedding:
+ _target_: ocl.neural_networks.positional_embedding.SoftPositionEmbed
+ n_spatial_dims: 2
+ feature_dim: ${models.perceptual_grouping.object_dim}
+ cnn_channel_order: true
+ savi_style: true
+ object_features_path: perceptual_grouping.objects
+
+plugins:
+ optimize_parameters:
+ optimizer:
+ lr: 0.0002
+ lr_scheduler:
+ warmup_steps: 2500
+ T_max: ${trainer.max_steps}
+
+losses:
+ mse:
+ _target_: ocl.losses.ReconstructionLoss
+ loss_type: mse
+ input_path: object_decoder.reconstruction
+ target_path: input.image
+
+visualizations:
+ input:
+ _target_: ocl.visualizations.Image
+ denormalization: "${lambda_fn:'lambda t: t * 0.5 + 0.5'}"
+ image_path: input.image
+ reconstruction:
+ _target_: ocl.visualizations.Image
+ denormalization: ${..input.denormalization}
+ image_path: object_decoder.reconstruction
+ objects:
+ _target_: ocl.visualizations.VisualObject
+ denormalization: ${..input.denormalization}
+ object_path: object_decoder.object_reconstructions
+ mask_path: object_decoder.masks
+ pred_segmentation:
+ _target_: ocl.visualizations.Segmentation
+ denormalization: ${..input.denormalization}
+ image_path: input.image
+ mask_path: object_decoder.masks
diff --git a/configs/experiment/slot_attention/_metrics_clevr.yaml b/configs/experiment/slot_attention/_metrics_clevr.yaml
new file mode 100644
index 0000000..a0c9392
--- /dev/null
+++ b/configs/experiment/slot_attention/_metrics_clevr.yaml
@@ -0,0 +1,23 @@
+# @package _global_
+# Metrics for CLEVR-like datasets
+defaults:
+ - /metrics/ari_metric@evaluation_metrics.ari
+ - /metrics/average_best_overlap_metric@evaluation_metrics.abo
+ - /metrics/ari_metric@evaluation_metrics.ari_bg
+ - /metrics/average_best_overlap_metric@evaluation_metrics.abo_bg
+evaluation_metrics:
+ ari:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ abo:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ ignore_background: true
+ ari_bg:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ foreground: false
+ abo_bg:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ ignore_background: false
\ No newline at end of file
diff --git a/configs/experiment/slot_attention/_metrics_coco.yaml b/configs/experiment/slot_attention/_metrics_coco.yaml
new file mode 100644
index 0000000..9184c83
--- /dev/null
+++ b/configs/experiment/slot_attention/_metrics_coco.yaml
@@ -0,0 +1,36 @@
+# @package _global_
+# Metrics for COCO-like datasets
+defaults:
+ - /metrics/ari_metric@evaluation_metrics.instance_mask_ari
+ - /metrics/unsupervised_mask_iou_metric@evaluation_metrics.instance_mask_iou
+ - /metrics/unsupervised_mask_iou_metric@evaluation_metrics.segmentation_mask_iou
+ - /metrics/average_best_overlap_metric@evaluation_metrics.instance_mask_abo
+ - /metrics/average_best_overlap_metric@evaluation_metrics.segmentation_mask_abo
+ - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
+
+evaluation_metrics:
+ instance_mask_ari:
+ prediction_path: object_decoder.masks
+ target_path: input.instance_mask
+ foreground: false
+ ignore_overlaps: true
+ convert_target_one_hot: true
+ instance_mask_iou:
+ prediction_path: object_decoder.masks
+ target_path: input.instance_mask
+ ignore_overlaps: true
+ segmentation_mask_iou:
+ prediction_path: object_decoder.masks
+ target_path: input.segmentation_mask
+ instance_mask_abo:
+ prediction_path: object_decoder.masks
+ target_path: input.instance_mask
+ ignore_overlaps: true
+ segmentation_mask_abo:
+ prediction_path: object_decoder.masks
+ target_path: input.segmentation_mask
+ instance_mask_corloc:
+ prediction_path: object_decoder.masks
+ target_path: input.instance_mask
+ use_threshold: False
+ ignore_overlaps: true
diff --git a/configs/experiment/slot_attention/_preprocessing_cater.yaml b/configs/experiment/slot_attention/_preprocessing_cater.yaml
new file mode 100644
index 0000000..cc008a3
--- /dev/null
+++ b/configs/experiment/slot_attention/_preprocessing_cater.yaml
@@ -0,0 +1,34 @@
+# @package _global_
+defaults:
+ - /plugins/multi_element_preprocessing@plugins.03_preprocessing
+ - _self_
+
+plugins:
+ 03_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 128
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.5, 0.5, 0.5]
+ std: [0.5, 0.5, 0.5]
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.Resize
+ size: 128
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.5, 0.5, 0.5]
+ std: [0.5, 0.5, 0.5]
+ mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.MultiMaskToTensor
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 128
diff --git a/configs/experiment/slot_attention/_preprocessing_clevr.yaml b/configs/experiment/slot_attention/_preprocessing_clevr.yaml
new file mode 100644
index 0000000..9e1d55a
--- /dev/null
+++ b/configs/experiment/slot_attention/_preprocessing_clevr.yaml
@@ -0,0 +1,38 @@
+# @package _global_
+defaults:
+ - /plugins/multi_element_preprocessing@plugins.03_preprocessing
+
+plugins:
+ 03_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.CenterCrop
+ size: [192, 192]
+ - _target_: torchvision.transforms.Resize
+ size: 128
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.5, 0.5, 0.5]
+ std: [0.5, 0.5, 0.5]
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.CenterCrop
+ size: [192, 192]
+ - _target_: torchvision.transforms.Resize
+ size: 128
+ - _target_: torchvision.transforms.Normalize
+ mean: [0.5, 0.5, 0.5]
+ std: [0.5, 0.5, 0.5]
+ mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.MaskToTensor
+ - _target_: torchvision.transforms.CenterCrop
+ size: [192, 192]
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 128
diff --git a/configs/experiment/slot_attention/_preprocessing_clevr_no_norm.yaml b/configs/experiment/slot_attention/_preprocessing_clevr_no_norm.yaml
new file mode 100644
index 0000000..ca27f74
--- /dev/null
+++ b/configs/experiment/slot_attention/_preprocessing_clevr_no_norm.yaml
@@ -0,0 +1,32 @@
+# @package _global_
+defaults:
+ - /plugins/multi_element_preprocessing@plugins.03_preprocessing
+
+plugins:
+ 03_preprocessing:
+ training_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.CenterCrop
+ size: [192, 192]
+ - _target_: torchvision.transforms.Resize
+ size: 128
+ evaluation_transforms:
+ image:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: torchvision.transforms.ToTensor
+ - _target_: torchvision.transforms.CenterCrop
+ size: [192, 192]
+ - _target_: torchvision.transforms.Resize
+ size: 128
+ mask:
+ _target_: torchvision.transforms.Compose
+ transforms:
+ - _target_: ocl.preprocessing.MaskToTensor
+ - _target_: torchvision.transforms.CenterCrop
+ size: [192, 192]
+ - _target_: ocl.preprocessing.ResizeNearestExact
+ size: 128
diff --git a/configs/experiment/slot_attention/clevr10.yaml b/configs/experiment/slot_attention/clevr10.yaml
new file mode 100644
index 0000000..638d722
--- /dev/null
+++ b/configs/experiment/slot_attention/clevr10.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+defaults:
+ - /experiment/slot_attention/_base
+ - /dataset: clevr
+ - /experiment/slot_attention/_preprocessing_clevr
+ - /experiment/slot_attention/_metrics_clevr
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ conditioning:
+ n_slots: 11
diff --git a/configs/experiment/slot_attention/clevr10_adaslot.yaml b/configs/experiment/slot_attention/clevr10_adaslot.yaml
new file mode 100644
index 0000000..fa1cc99
--- /dev/null
+++ b/configs/experiment/slot_attention/clevr10_adaslot.yaml
@@ -0,0 +1,51 @@
+# @package _global_
+defaults:
+ - /experiment/slot_attention/_base_gumbel
+ - /dataset: clevr
+ - /experiment/slot_attention/_preprocessing_clevr
+ - /experiment/slot_attention/_metrics_clevr
+ - /metrics/tensor_statistic@training_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@training_metrics.slots_keep_prob
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot_pixel.GroupingImgGumbel
+ conditioning:
+ n_slots: 11
+
+ perceptual_grouping:
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.SlotAttentionDecoderGumbel
+ left_mask_path: None
+ mask_type: mask_normalized
+
+losses:
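+  # Same sparsity penalty as in the feature-reconstruction configs, but with a larger linear
+  # weight (10 instead of 0.1) for the pixel-reconstruction model.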
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 10
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+training_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+load_model_weight: /home/ubuntu/GitLab/bags-of-tricks/object-centric-learning-models/outputs/slot_attention/clevr10.yaml/2023-05-11_11-51-55/lightning_logs/version_0/checkpoints/epoch=457-step=500000.ckpt
\ No newline at end of file
diff --git a/configs/experiment/slot_attention/clevr10_adaslot_eval.yaml b/configs/experiment/slot_attention/clevr10_adaslot_eval.yaml
new file mode 100644
index 0000000..7723ae7
--- /dev/null
+++ b/configs/experiment/slot_attention/clevr10_adaslot_eval.yaml
@@ -0,0 +1,105 @@
+# @package _global_
+defaults:
+ - /experiment/slot_attention/_base_gumbel
+ - /dataset: clevr
+ - /experiment/slot_attention/_preprocessing_clevr
+ - /experiment/slot_attention/_metrics_clevr
+ - /metrics/tensor_statistic@evaluation_metrics.hard_keep_decision
+ - /metrics/tensor_statistic@evaluation_metrics.slots_keep_prob
+ - /metrics/ami_metric@evaluation_metrics.ami
+ - /metrics/nmi_metric@evaluation_metrics.nmi
+ - /metrics/purity_metric@evaluation_metrics.purity
+ - /metrics/precision_metric@evaluation_metrics.precision
+ - /metrics/recall_metric@evaluation_metrics.recall
+ - /metrics/f1_metric@evaluation_metrics.f1
+ - /metrics/mask_corloc_metric@evaluation_metrics.instance_mask_corloc
+ - _self_
+
+# The following parameters assume training on 8 GPUs, leading to an effective batch size of 64.
+trainer:
+ gpus: 8
+ max_steps: 500000
+ max_epochs: null
+ strategy: ddp
+dataset:
+ num_workers: 4
+ batch_size: 8
+
+models:
+ _target_: ocl.models.image_grouping_adaslot_pixel.GroupingImgGumbel
+ conditioning:
+ n_slots: 11
+
+ perceptual_grouping:
+ low_bound: 0
+
+ object_decoder:
+ _target_: ocl.decoding.SlotAttentionDecoderGumbel
+ left_mask_path: None
+ mask_type: mask_normalized
+
+losses:
+ sparse_penalty:
+ _target_: ocl.losses.SparsePenalty
+ linear_weight: 10
+ quadratic_weight: 0.0
+ quadratic_bias: 0.5
+ input_path: hard_keep_decision
+
+evaluation_metrics:
+ hard_keep_decision:
+ path: hard_keep_decision
+ reduction: sum
+
+ slots_keep_prob:
+ path: slots_keep_prob
+ reduction: mean
+
+ ami:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ nmi:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ purity:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ precision:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ recall:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ f1:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ foreground: true
+ convert_target_one_hot: false
+ ignore_overlaps: false
+
+ instance_mask_corloc:
+ prediction_path: object_decoder.masks
+ target_path: input.mask
+ use_threshold: False
+ ignore_background: True
+ ignore_overlaps: False
diff --git a/framework.png b/framework.png
new file mode 100644
index 0000000..e42d738
Binary files /dev/null and b/framework.png differ
diff --git a/ocl/__init__.py b/ocl/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ocl/base.py b/ocl/base.py
new file mode 100644
index 0000000..7549fd8
--- /dev/null
+++ b/ocl/base.py
@@ -0,0 +1,279 @@
+import abc
+import dataclasses
+import itertools
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pluggy
+import torch
+from torch import nn
+from torchtyping import TensorType
+
+PluggyHookRelay = pluggy._hooks._HookRelay # Type alias for more readable function signatures
+
+ConditioningOutput = TensorType["batch_size", "n_objects", "object_dim"] # noqa: F821
+
+
+class Conditioning(nn.Module, metaclass=abc.ABCMeta):
+ """Base class for conditioning perceptual grouping."""
+
+ @abc.abstractmethod
+ def forward(self, *args) -> ConditioningOutput:
+ pass
+
+
+@dataclasses.dataclass
+class FrameFeatures:
+ """Features associated with a single frame."""
+
+ features: TensorType["batch_size", "n_spatial_features", "feature_dim"] # noqa: F821
+ positions: TensorType["n_spatial_features", "spatial_dims"] # noqa: F821
+
+
+@dataclasses.dataclass
+class FeatureExtractorOutput:
+ """Output of feature extractor."""
+
+ features: TensorType["batch_size", "frames", "n_spatial_features", "feature_dim"] # noqa: F821
+ positions: TensorType["n_spatial_features", "spatial_dims"] # noqa: F821
+ aux_features: Optional[Dict[str, torch.Tensor]] = None
+
+ def __iter__(self):
+ """Iterate over features and positions per frame."""
+ for frame_features in torch.split(self.features, 1, dim=1):
+ yield FrameFeatures(frame_features.squeeze(1), self.positions)
+
+
+class FeatureExtractor(nn.Module, metaclass=abc.ABCMeta):
+ """Abstract base class for Feature Extractors.
+
+ We expect that the forward method returns a flattened representation of the features, to make
+ outputs consistent and not dependent on equal spacing or the dimensionality of the spatial
+ information.
+ """
+
+ @property
+ @abc.abstractmethod
+ def feature_dim(self):
+ """Get dimensionality of the features.
+
+ Returns:
+ int: The dimensionality of the features.
+ """
+
+ @abc.abstractmethod
+ def forward(self, inputs: torch.Tensor) -> FeatureExtractorOutput:
+ pass
+
+
+@dataclasses.dataclass
+class PerceptualGroupingOutput:
+ """Output of a perceptual grouping algorithm."""
+
+ objects: TensorType["batch_size", "n_objects", "object_dim"] # noqa: F821
+ is_empty: Optional[TensorType["batch_size", "n_objects"]] = None # noqa: F821
+ feature_attributions: Optional[
+ TensorType["batch_size", "n_objects", "n_spatial_features"] # noqa: F821
+ ] = None
+
+
+class PerceptualGrouping(nn.Module, metaclass=abc.ABCMeta):
+ """Abstract base class of a perceptual grouping algorithm."""
+
+ @abc.abstractmethod
+ def forward(self, extracted_features: FeatureExtractorOutput) -> PerceptualGroupingOutput:
+ pass
+
+ @property
+ @abc.abstractmethod
+ def object_dim(self):
+ pass
+
+
+class Instances:
+ """Modified from Detectron2 (https://github.com/facebookresearch/detectron2).
+
+ This class represents a list of instances in an image.
+ It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields".
+ All fields must have the same ``__len__`` which is the number of instances.
+
+ All other (non-field) attributes of this class are considered private:
+ they must start with '_' and are not modifiable by a user.
+
+ Some basic usage:
+
+ 1. Set/get/check a field:
+
+ .. code-block:: python
+
+ instances.gt_boxes = Boxes(...)
+ print(instances.pred_masks) # a tensor of shape (N, H, W)
+ print('gt_masks' in instances)
+
+ 2. ``len(instances)`` returns the number of instances
+ 3. Indexing: ``instances[indices]`` will apply the indexing on all the fields
+ and returns a new :class:`Instances`.
+    Typically, ``indices`` is an integer vector of indices,
+ or a binary mask of length ``num_instances``
+
+ .. code-block:: python
+
+ category_3_detections = instances[instances.pred_classes == 3]
+ confident_detections = instances[instances.scores > 0.9]
+ """
+
+ def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
+ """Init function.
+
+ Args:
+ image_size (height, width): the spatial size of the image.
+ kwargs: fields to add to this `Instances`.
+ """
+ self._image_size = image_size
+ self._fields: Dict[str, Any] = {}
+ for k, v in kwargs.items():
+ self.set(k, v)
+
+ @property
+ def image_size(self) -> Tuple[int, int]:
+ return self._image_size
+
+ def __setattr__(self, name: str, val: Any) -> None:
+ if name.startswith("_"):
+ super().__setattr__(name, val)
+ else:
+ self.set(name, val)
+
+ def __getattr__(self, name: str) -> Any:
+ if name == "_fields" or name not in self._fields:
+ raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
+ return self._fields[name]
+
+ def set(self, name: str, value: Any) -> None:
+ """Set the field named `name` to `value`.
+
+ The length of `value` must be the number of instances,
+ and must agree with other existing fields in this object.
+ """
+ data_len = len(value)
+ if len(self._fields):
+ assert (
+ len(self) == data_len
+        ), "Adding a field of length {} to an Instances of length {}".format(data_len, len(self))
+ self._fields[name] = value
+
+ def has(self, name: str) -> bool:
+ """Returns whether the field called `name` exists."""
+ return name in self._fields
+
+ def remove(self, name: str) -> None:
+ """Remove the field called `name`."""
+ del self._fields[name]
+
+ def get(self, name: str) -> Any:
+ """Returns the field called `name`."""
+ return self._fields[name]
+
+ def get_fields(self) -> Dict[str, Any]:
+ """Get field.
+
+ Returns:
+ dict: a dict which maps names (str) to data of the fields
+
+ Modifying the returned dict will modify this instance.
+ """
+ return self._fields
+
+ # Tensor-like methods
+ def to(self, *args: Any, **kwargs: Any) -> "Instances":
+ """To device.
+
+ Returns:
+ Instances: all fields are called with a `to(device)`, if the field has this method.
+ """
+ ret = Instances(self._image_size)
+ for k, v in self._fields.items():
+ if hasattr(v, "to"):
+ v = v.to(*args, **kwargs)
+ ret.set(k, v)
+ return ret
+
+ def numpy(self):
+ ret = Instances(self._image_size)
+ for k, v in self._fields.items():
+ if hasattr(v, "numpy"):
+ v = v.numpy()
+ ret.set(k, v)
+ return ret
+
+ def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances":
+ """Get entry.
+
+ Args:
+ item: an index-like object and will be used to index all the fields.
+
+ Returns:
+ If `item` is a string, return the data in the corresponding field.
+ Otherwise, returns an `Instances` where all fields are indexed by `item`.
+ """
+ if type(item) == int:
+ if item >= len(self) or item < -len(self):
+ raise IndexError("Instances index out of range!")
+ else:
+ item = slice(item, None, len(self))
+
+ ret = Instances(self._image_size)
+ for k, v in self._fields.items():
+ ret.set(k, v[item])
+ return ret
+
+ def __len__(self) -> int:
+ for v in self._fields.values():
+ # use __len__ because len() has to be int and is not friendly to tracing
+ return v.__len__()
+ raise NotImplementedError("Empty Instances does not support __len__!")
+
+ def __iter__(self):
+ raise NotImplementedError("`Instances` object is not iterable!")
+
+ @staticmethod
+ def cat(instance_lists: List["Instances"]) -> "Instances":
+ """Concatenate instances.
+
+ Args:
+ instance_lists (list[Instances])
+
+ Returns:
+ Instances
+ """
+ assert all(isinstance(i, Instances) for i in instance_lists)
+ assert len(instance_lists) > 0
+ if len(instance_lists) == 1:
+ return instance_lists[0]
+
+ image_size = instance_lists[0].image_size
+ for i in instance_lists[1:]:
+ assert i.image_size == image_size
+ ret = Instances(image_size)
+ for k in instance_lists[0]._fields.keys():
+ values = [i.get(k) for i in instance_lists]
+ v0 = values[0]
+ if isinstance(v0, torch.Tensor):
+ values = torch.cat(values, dim=0)
+ elif isinstance(v0, list):
+ values = list(itertools.chain(*values))
+ elif hasattr(type(v0), "cat"):
+ values = type(v0).cat(values)
+ else:
+ raise ValueError("Unsupported type {} for concatenation".format(type(v0)))
+ ret.set(k, values)
+ return ret
+
+ def __str__(self) -> str:
+ s = self.__class__.__name__ + "("
+ s += "num_instances={}, ".format(len(self))
+ s += "image_height={}, ".format(self._image_size[0])
+ s += "image_width={}, ".format(self._image_size[1])
+ s += "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items())))
+ return s
+
+ __repr__ = __str__
diff --git a/ocl/cli/cli_utils.py b/ocl/cli/cli_utils.py
new file mode 100644
index 0000000..c4efccf
--- /dev/null
+++ b/ocl/cli/cli_utils.py
@@ -0,0 +1,29 @@
+import glob
+import os
+
+from hydra.core.hydra_config import HydraConfig
+
+
+def get_commandline_config_path():
+ """Get the path of a config path specified on the command line."""
+ hydra_cfg = HydraConfig.get()
+ config_sources = hydra_cfg.runtime.config_sources
+ config_path = None
+ for source in config_sources:
+ if source.schema == "file" and source.provider == "command-line":
+ config_path = source.path
+ break
+ return config_path
+
+
+def find_checkpoint(path):
+ """Find checkpoint in output path of previous run."""
+ checkpoints = glob.glob(
+ os.path.join(path, "lightning_logs", "version_*", "checkpoints", "*.ckpt")
+ )
+ checkpoints.sort()
+ # Return the last checkpoint.
+ # TODO (hornmax): If more than one checkpoint is stored this might not lead to the most recent
+ # checkpoint being loaded. Generally, I think this is ok as we still allow people to set the
+ # checkpoint manually.
+ return checkpoints[-1]
diff --git a/ocl/cli/compute_dataset_size.py b/ocl/cli/compute_dataset_size.py
new file mode 100644
index 0000000..5be7a91
--- /dev/null
+++ b/ocl/cli/compute_dataset_size.py
@@ -0,0 +1,75 @@
+"""Script to compute the size of a dataset.
+
+This is useful when subsampling data using transformations in order to determine the final dataset
+size. The size of the dataset is typically needed when running distributed training in order to
+ensure that all nodes and GPU training processes are presented with the same number of batches.
+"""
+import dataclasses
+import logging
+import os
+from typing import Dict
+
+import hydra
+import hydra_zen
+import tqdm
+from pluggy import PluginManager
+
+import ocl.hooks
+from ocl.config.datasets import DataModuleConfig
+
+
+@dataclasses.dataclass
+class ComputeSizeConfig:
+ """Configuration of a training run."""
+
+ dataset: DataModuleConfig
+ plugins: Dict[str, Dict] = dataclasses.field(default_factory=dict)
+
+
+hydra.core.config_store.ConfigStore.instance().store(
+ name="compute_size_config",
+ node=ComputeSizeConfig,
+)
+
+
+@hydra.main(config_name="compute_size_config", config_path="../../configs", version_base="1.1")
+def compute_size(config: ComputeSizeConfig):
+ pm = PluginManager("ocl")
+ pm.add_hookspecs(ocl.hooks)
+
+ datamodule = hydra_zen.instantiate(config.dataset, hooks=pm.hook)
+ pm.register(datamodule)
+
+ plugins = hydra_zen.instantiate(config.plugins)
+ for plugin_name in sorted(plugins.keys())[::-1]:
+ pm.register(plugins[plugin_name])
+
+ # Compute dataset sizes
+ # TODO(hornmax): This is needed for webdataset shuffling, is there a way to make this more
+ # elegant and less specific?
+ os.environ["WDS_EPOCH"] = str(0)
+ train_size = sum(
+ 1
+ for _ in tqdm.tqdm(
+ datamodule.train_data_iterator(), desc="Reading train split", unit="samples"
+ )
+ )
+ logging.info("Train split size: %d", train_size)
+ val_size = sum(
+ 1
+ for _ in tqdm.tqdm(
+ datamodule.val_data_iterator(), desc="Reading validation split", unit="samples"
+ )
+ )
+ logging.info("Validation split size: %d", val_size)
+ test_size = sum(
+ 1
+ for _ in tqdm.tqdm(
+ datamodule.test_data_iterator(), desc="Reading test split", unit="samples"
+ )
+ )
+ logging.info("Test split size: %d", test_size)
+
+
+if __name__ == "__main__":
+ compute_size()
diff --git a/ocl/cli/eval.py b/ocl/cli/eval.py
new file mode 100644
index 0000000..0479594
--- /dev/null
+++ b/ocl/cli/eval.py
@@ -0,0 +1,156 @@
+"""Train a slot attention type model."""
+import dataclasses
+from typing import Any, Dict, Optional
+
+import hydra
+import hydra_zen
+import pytorch_lightning as pl
+from pluggy import PluginManager
+
+import ocl.hooks
+from ocl import base
+from ocl.combined_model import CombinedModel
+from ocl.config.datasets import DataModuleConfig
+from ocl.config.metrics import MetricConfig
+from ocl.plugins import Plugin
+from ocl.cli import eval_utils
+
+TrainerConf = hydra_zen.builds(
+ pl.Trainer, max_epochs=100, zen_partial=False, populate_full_signature=True
+)
+
+
+@dataclasses.dataclass
+class TrainingConfig:
+    """Configuration of a training run, reused here to rebuild the model for evaluation."""
+
+ dataset: DataModuleConfig
+    models: Any  # A dict is wrapped in `utils.Combined`, otherwise used as the model directly.
+ losses: Dict[str, Any]
+ visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ plugins: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ trainer: TrainerConf = TrainerConf
+ training_vis_frequency: Optional[int] = None
+ training_metrics: Optional[Dict[str, MetricConfig]] = None
+ evaluation_metrics: Optional[Dict[str, MetricConfig]] = None
+ load_checkpoint: Optional[str] = None
+ # load_model_weight: Optional[str] = None
+ seed: Optional[int] = None
+ experiment: Optional[Any] = None
+ root_output_folder: Optional[str] = None
+
+
+hydra.core.config_store.ConfigStore.instance().store(
+ name="training_config",
+ node=TrainingConfig,
+)
+
+
+def create_plugin_manager() -> PluginManager:
+ pm = PluginManager("ocl")
+ pm.add_hookspecs(ocl.hooks)
+ return pm
+
+
+def build_and_register_datamodule_from_config(
+ config: TrainingConfig,
+ hooks: base.PluggyHookRelay,
+ plugin_manager: Optional[PluginManager] = None,
+ **datamodule_kwargs,
+) -> pl.LightningDataModule:
+ datamodule = hydra_zen.instantiate(
+ config.dataset, hooks=hooks, _convert_="all", **datamodule_kwargs
+ )
+
+ if plugin_manager:
+ plugin_manager.register(datamodule)
+
+ return datamodule
+
+
+def build_and_register_plugins_from_config(
+ config: TrainingConfig, plugin_manager: Optional[PluginManager] = None
+) -> Dict[str, Plugin]:
+ plugins = hydra_zen.instantiate(config.plugins)
+    # Use lexicographical sorting of the plugin names to influence registration order; some
+    # plugins need to be called before others. Pluggy calls hooks in LIFO order of registration
+    # (last registered, first called), which is slightly unintuitive, so we register plugins in
+    # reverse sorted order. Hooks then run in FIFO order with respect to the sorted names.
+ if plugin_manager:
+ for plugin_name in sorted(plugins.keys())[::-1]:
+ plugin_manager.register(plugins[plugin_name])
+
+ return plugins
+
+
+def build_model_from_config(
+ config: TrainingConfig,
+ hooks: base.PluggyHookRelay,
+ checkpoint_path: Optional[str] = None,
+) -> pl.LightningModule:
+ models = hydra_zen.instantiate(config.models, _convert_="all")
+ losses = hydra_zen.instantiate(config.losses, _convert_="all")
+ visualizations = hydra_zen.instantiate(config.visualizations, _convert_="all")
+
+ training_metrics = hydra_zen.instantiate(config.training_metrics)
+ evaluation_metrics = hydra_zen.instantiate(config.evaluation_metrics)
+
+ train_vis_freq = config.training_vis_frequency if config.training_vis_frequency else 100
+
+ if checkpoint_path is None:
+ model = CombinedModel(
+ models=models,
+ losses=losses,
+ visualizations=visualizations,
+ hooks=hooks,
+ training_metrics=training_metrics,
+ evaluation_metrics=evaluation_metrics,
+ vis_log_frequency=train_vis_freq,
+ )
+ else:
+ model = CombinedModel.load_from_checkpoint(
+ checkpoint_path,
+ strict=False,
+ models=models,
+ losses=losses,
+ visualizations=visualizations,
+ hooks=hooks,
+ training_metrics=training_metrics,
+ evaluation_metrics=evaluation_metrics,
+ vis_log_frequency=train_vis_freq,
+ )
+ return model
+
+
+@hydra.main(config_name="training_config", config_path="../../configs/", version_base="1.1")
+def train(config: TrainingConfig):
+ # Set all relevant random seeds. If `config.seed is None`, the function samples a random value.
+ # The function takes care of correctly distributing the seed across nodes in multi-node training,
+ # and assigns each dataloader worker a different random seed.
+ # IMPORTANTLY, we need to take care not to set a custom `worker_init_fn` function on the
+ # dataloaders (or take care of worker seeding ourselves).
+ pl.seed_everything(config.seed, workers=True)
+
+    if not config.load_checkpoint:
+        raise ValueError("Evaluation requires `load_checkpoint` to be set.")
+    checkpoint_path = hydra.utils.to_absolute_path(config.load_checkpoint)
+    datamodule, model, pm = eval_utils.build_from_train_config(config, checkpoint_path)
+
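+    # Evaluate on a single GPU, regardless of how many GPUs were used for training.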
+ trainer: pl.Trainer = hydra_zen.instantiate(
+ config.trainer,
+ _convert_="all",
+ enable_progress_bar=True,
+ gpus=[0],
+ )
+
+    print("****** Starting validation ******")
+ trainer.validate(model, datamodule.val_dataloader())
+
+
+if __name__ == "__main__":
+ train()
diff --git a/ocl/cli/eval_utils.py b/ocl/cli/eval_utils.py
new file mode 100644
index 0000000..74a8d2e
--- /dev/null
+++ b/ocl/cli/eval_utils.py
@@ -0,0 +1,126 @@
+import pathlib
+import pickle
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional
+
+import numpy
+import pytorch_lightning as pl
+import torch
+
+from ocl import path_defaults
+from ocl.cli import train
+from ocl.utils.trees import get_tree_element
+
+
+def build_from_train_config(
+ config: train.TrainingConfig, checkpoint_path: Optional[str], seed: bool = True
+):
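+    """Build the datamodule, model, and plugin manager described by a training config."""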
+ if seed:
+ pl.seed_everything(config.seed, workers=True)
+
+ pm = train.create_plugin_manager()
+ datamodule = train.build_and_register_datamodule_from_config(config, pm.hook, pm)
+ train.build_and_register_plugins_from_config(config, pm)
+ model = train.build_model_from_config(config, pm.hook, checkpoint_path)
+
+ return datamodule, model, pm
+
+
+class ExtractDataFromPredictions(pl.callbacks.Callback):
+ """Callback used for extracting model outputs during validation and prediction."""
+
+ def __init__(
+ self,
+ paths: List[str],
+ output_paths: Optional[List[str]] = None,
+ transform: Optional[Callable] = None,
+ max_samples: Optional[int] = None,
+ flatten_batches: bool = False,
+ ):
+ self.paths = paths
+ self.output_paths = output_paths if output_paths is not None else paths
+ self.transform = transform
+ self.max_samples = max_samples
+ self.flatten_batches = flatten_batches
+
+ self.outputs = defaultdict(list)
+ self._n_samples = 0
+
+ def _start(self):
+ self._n_samples = 0
+ self.outputs = defaultdict(list)
+
+ def _process_outputs(self, outputs, batch):
+ if self.max_samples is not None and self._n_samples >= self.max_samples:
+ return
+
+        # Make the input batch available alongside the model outputs, then extract the requested
+        # paths from this combined tree.
+        data = {path_defaults.INPUT: batch, **outputs}
+        data = {path: get_tree_element(data, path.split(".")) for path in self.paths}
+
+ if self.transform:
+ data = self.transform(data)
+
+ first_path = True
+ for path in self.output_paths:
+ elems = data[path].detach().cpu()
+ if not self.flatten_batches:
+ elems = [elems]
+
+ for idx in range(len(elems)):
+ self.outputs[path].append(elems[idx])
+ if first_path:
+ self._n_samples += 1
+
+ first_path = False
+
+ def on_validation_start(self, trainer, pl_module):
+ self._start()
+
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ assert (
+ outputs is not None
+ ), "Model returned no outputs. Set `model.return_outputs_on_validation=True`"
+ self._process_outputs(outputs, batch)
+
+ def on_predict_start(self, trainer, pl_module):
+ self._start()
+
+ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ self._process_outputs(outputs, batch)
+
+ def get_outputs(self) -> List[Dict[str, Any]]:
+ state = []
+ for idx in range(self._n_samples):
+ state.append({})
+ for key, values in self.outputs.items():
+ state[-1][key] = values[idx]
+
+ return state
+
+
+def save_outputs(dir_path: str, outputs: List[Dict[str, Any]], verbose: bool = False):
+ """Save outputs to disk in numpy or pickle format."""
+ dir_path = pathlib.Path(dir_path)
+ dir_path.mkdir(parents=True, exist_ok=True)
+
+ def get_path(path, prefix, key, extension):
+ return str(path / f"{prefix}.{key}.{extension}")
+
+ idx_fmt = "{:0" + str(len(str(len(outputs)))) + "d}" # Get number of total digits
+ for idx, entry in enumerate(outputs):
+ idx_prefix = idx_fmt.format(idx)
+ for key, value in entry.items():
+ if isinstance(value, torch.Tensor):
+ value = value.numpy()
+
+ if isinstance(value, numpy.ndarray):
+ path = get_path(dir_path, idx_prefix, key, "npy")
+ if verbose:
+ print(f"Saving numpy array to {path}.")
+ numpy.save(path, value)
+ else:
+ path = get_path(dir_path, idx_prefix, key, "pkl")
+ if verbose:
+ print(f"Saving pickle to {path}.")
+ with open(path, "wb") as f:
+ pickle.dump(value, f)
diff --git a/ocl/cli/train.py b/ocl/cli/train.py
new file mode 100644
index 0000000..a072e24
--- /dev/null
+++ b/ocl/cli/train.py
@@ -0,0 +1,158 @@
+"""Train a slot attention type model."""
+import dataclasses
+from typing import Any, Dict, Optional
+
+import hydra
+import hydra_zen
+import pytorch_lightning as pl
+from pluggy import PluginManager
+
+import ocl.hooks
+from ocl import base
+from ocl.combined_model import CombinedModel
+from ocl.config.datasets import DataModuleConfig
+from ocl.config.metrics import MetricConfig
+from ocl.plugins import Plugin
+
+TrainerConf = hydra_zen.builds(
+ pl.Trainer, max_epochs=100, zen_partial=False, populate_full_signature=True
+)
+
+
+@dataclasses.dataclass
+class TrainingConfig:
+ """Configuration of a training run."""
+
+ dataset: DataModuleConfig
+    models: Any  # A dict is wrapped in `utils.Combined`, otherwise used as the model directly.
+ losses: Dict[str, Any]
+ visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ plugins: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ trainer: TrainerConf = TrainerConf
+ training_vis_frequency: Optional[int] = None
+ training_metrics: Optional[Dict[str, MetricConfig]] = None
+ evaluation_metrics: Optional[Dict[str, MetricConfig]] = None
+ load_checkpoint: Optional[str] = None
+ seed: Optional[int] = None
+ experiment: Optional[Any] = None
+ root_output_folder: Optional[str] = None
+
+
+hydra.core.config_store.ConfigStore.instance().store(
+ name="training_config",
+ node=TrainingConfig,
+)
+
+
+def create_plugin_manager() -> PluginManager:
+ pm = PluginManager("ocl")
+ pm.add_hookspecs(ocl.hooks)
+ return pm
+
+
+def build_and_register_datamodule_from_config(
+ config: TrainingConfig,
+ hooks: base.PluggyHookRelay,
+ plugin_manager: Optional[PluginManager] = None,
+ **datamodule_kwargs,
+) -> pl.LightningDataModule:
+ datamodule = hydra_zen.instantiate(
+ config.dataset, hooks=hooks, _convert_="all", **datamodule_kwargs
+ )
+
+ if plugin_manager:
+ plugin_manager.register(datamodule)
+
+ return datamodule
+
+
+def build_and_register_plugins_from_config(
+ config: TrainingConfig, plugin_manager: Optional[PluginManager] = None
+) -> Dict[str, Plugin]:
+ plugins = hydra_zen.instantiate(config.plugins)
+    # Use lexicographical sorting of the plugin names to influence registration order; some
+    # plugins need to be called before others. Pluggy calls hooks in LIFO order of registration
+    # (last registered, first called), which is slightly unintuitive, so we register plugins in
+    # reverse sorted order. Hooks then run in FIFO order with respect to the sorted names.
+ if plugin_manager:
+ for plugin_name in sorted(plugins.keys())[::-1]:
+ plugin_manager.register(plugins[plugin_name])
+
+ return plugins
+
+
+def build_model_from_config(
+ config: TrainingConfig,
+ hooks: base.PluggyHookRelay,
+ checkpoint_path: Optional[str] = None,
+) -> pl.LightningModule:
+ models = hydra_zen.instantiate(config.models, _convert_="all")
+ losses = hydra_zen.instantiate(config.losses, _convert_="all")
+ visualizations = hydra_zen.instantiate(config.visualizations, _convert_="all")
+
+ training_metrics = hydra_zen.instantiate(config.training_metrics)
+ evaluation_metrics = hydra_zen.instantiate(config.evaluation_metrics)
+
+ train_vis_freq = config.training_vis_frequency if config.training_vis_frequency else 100
+
+ if checkpoint_path is None:
+ model = CombinedModel(
+ models=models,
+ losses=losses,
+ visualizations=visualizations,
+ hooks=hooks,
+ training_metrics=training_metrics,
+ evaluation_metrics=evaluation_metrics,
+ vis_log_frequency=train_vis_freq,
+ )
+ else:
+ model = CombinedModel.load_from_checkpoint(
+ checkpoint_path,
+ strict=False,
+ models=models,
+ losses=losses,
+ visualizations=visualizations,
+ hooks=hooks,
+ training_metrics=training_metrics,
+ evaluation_metrics=evaluation_metrics,
+ vis_log_frequency=train_vis_freq,
+ )
+ return model
+
+
+@hydra.main(config_name="training_config", config_path="../../configs/", version_base="1.1")
+def train(config: TrainingConfig):
+ # Set all relevant random seeds. If `config.seed is None`, the function samples a random value.
+ # The function takes care of correctly distributing the seed across nodes in multi-node training,
+ # and assigns each dataloader worker a different random seed.
+ # IMPORTANTLY, we need to take care not to set a custom `worker_init_fn` function on the
+ # dataloaders (or take care of worker seeding ourselves).
+ pl.seed_everything(config.seed, workers=True)
+
+ pm = create_plugin_manager()
+
+ datamodule = build_and_register_datamodule_from_config(config, pm.hook, pm)
+
+ build_and_register_plugins_from_config(config, pm)
+
+ if config.load_checkpoint:
+ checkpoint_path = hydra.utils.to_absolute_path(config.load_checkpoint)
+ else:
+ checkpoint_path = None
+
+ model = build_model_from_config(config, pm.hook)
+
+ callbacks = hydra_zen.instantiate(config.trainer.callbacks, _convert_="all")
+ callbacks = callbacks if callbacks else []
+ if config.trainer.logger is not False:
+ lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
+ callbacks.append(lr_monitor)
+
+ trainer: pl.Trainer = hydra_zen.instantiate(config.trainer, callbacks=callbacks, _convert_="all")
+
+ trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/ocl/cli/train_adaslot.py b/ocl/cli/train_adaslot.py
new file mode 100644
index 0000000..c3c6f44
--- /dev/null
+++ b/ocl/cli/train_adaslot.py
@@ -0,0 +1,166 @@
+"""Train a slot attention type model."""
+import dataclasses
+from typing import Any, Dict, Optional
+
+import hydra
+import hydra_zen
+import pytorch_lightning as pl
+from pluggy import PluginManager
+
+import ocl.hooks
+from ocl import base
+from ocl.combined_model import CombinedModel
+from ocl.config.datasets import DataModuleConfig
+from ocl.config.metrics import MetricConfig
+from ocl.plugins import Plugin
+import torch
+
+TrainerConf = hydra_zen.builds(
+ pl.Trainer, max_epochs=100, zen_partial=False, populate_full_signature=True
+)
+
+
+@dataclasses.dataclass
+class TrainingConfig:
+ """Configuration of a training run."""
+
+ dataset: DataModuleConfig
+    models: Any  # A dict is wrapped in `utils.Combined`, otherwise used as the model directly.
+ losses: Dict[str, Any]
+ visualizations: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ plugins: Dict[str, Any] = dataclasses.field(default_factory=dict)
+ trainer: TrainerConf = TrainerConf
+ training_vis_frequency: Optional[int] = None
+ training_metrics: Optional[Dict[str, MetricConfig]] = None
+ evaluation_metrics: Optional[Dict[str, MetricConfig]] = None
+ load_checkpoint: Optional[str] = None
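+    # Checkpoint used to initialize the model weights only (no optimizer or trainer state).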
+ load_model_weight: Optional[str] = None
+ seed: Optional[int] = None
+ experiment: Optional[Any] = None
+ root_output_folder: Optional[str] = None
+
+
+hydra.core.config_store.ConfigStore.instance().store(
+ name="training_config",
+ node=TrainingConfig,
+)
+
+
+def create_plugin_manager() -> PluginManager:
+ pm = PluginManager("ocl")
+ pm.add_hookspecs(ocl.hooks)
+ return pm
+
+
+def build_and_register_datamodule_from_config(
+ config: TrainingConfig,
+ hooks: base.PluggyHookRelay,
+ plugin_manager: Optional[PluginManager] = None,
+ **datamodule_kwargs,
+) -> pl.LightningDataModule:
+ datamodule = hydra_zen.instantiate(
+ config.dataset, hooks=hooks, _convert_="all", **datamodule_kwargs
+ )
+
+ if plugin_manager:
+ plugin_manager.register(datamodule)
+
+ return datamodule
+
+
+def build_and_register_plugins_from_config(
+ config: TrainingConfig, plugin_manager: Optional[PluginManager] = None
+) -> Dict[str, Plugin]:
+ plugins = hydra_zen.instantiate(config.plugins)
+ # Use lexicographical sorting of the plugin names to influence the registration order. This is
+ # necessary in some cases, as certain plugins need to have their hooks called before others.
+ # Pluggy calls hooks in FILO (first in, last out) order of registration, which is slightly
+ # unintuitive. We therefore register plugins in reverse sorted order, which yields FIFO
+ # (first in, first out) behavior with respect to the sorted position.
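+ # For illustration (plugin names here are hypothetical, not taken from the shipped configs):
+ # plugins named "01_preprocessing" and "02_optimization" are registered as "02_optimization"
+ # first and "01_preprocessing" second, so pluggy calls the hooks of "01_preprocessing" before
+ # those of "02_optimization".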
+ if plugin_manager:
+ for plugin_name in sorted(plugins.keys())[::-1]:
+ plugin_manager.register(plugins[plugin_name])
+
+ return plugins
+
+
+def build_model_from_config(
+ config: TrainingConfig,
+ hooks: base.PluggyHookRelay,
+ checkpoint_path: Optional[str] = None,
+) -> pl.LightningModule:
+ models = hydra_zen.instantiate(config.models, _convert_="all")
+ losses = hydra_zen.instantiate(config.losses, _convert_="all")
+ visualizations = hydra_zen.instantiate(config.visualizations, _convert_="all")
+
+ training_metrics = hydra_zen.instantiate(config.training_metrics)
+ evaluation_metrics = hydra_zen.instantiate(config.evaluation_metrics)
+
+ train_vis_freq = config.training_vis_frequency if config.training_vis_frequency else 100
+
+ if checkpoint_path is None:
+ model = CombinedModel(
+ models=models,
+ losses=losses,
+ visualizations=visualizations,
+ hooks=hooks,
+ training_metrics=training_metrics,
+ evaluation_metrics=evaluation_metrics,
+ vis_log_frequency=train_vis_freq,
+ )
+ else:
+ model = CombinedModel.load_from_checkpoint(
+ checkpoint_path,
+ strict=False,
+ models=models,
+ losses=losses,
+ visualizations=visualizations,
+ hooks=hooks,
+ training_metrics=training_metrics,
+ evaluation_metrics=evaluation_metrics,
+ vis_log_frequency=train_vis_freq,
+ )
+ return model
+
+
+@hydra.main(config_name="training_config", config_path="../../configs/", version_base="1.1")
+def train(config: TrainingConfig):
+ # Set all relevant random seeds. If `config.seed is None`, the function samples a random value.
+ # The function takes care of correctly distributing the seed across nodes in multi-node training,
+ # and assigns each dataloader worker a different random seed.
+ # IMPORTANTLY, we need to take care not to set a custom `worker_init_fn` function on the
+ # dataloaders (or take care of worker seeding ourselves).
+ pl.seed_everything(config.seed, workers=True)
+
+ pm = create_plugin_manager()
+
+ datamodule = build_and_register_datamodule_from_config(config, pm.hook, pm)
+
+ build_and_register_plugins_from_config(config, pm)
+
+ if config.load_checkpoint:
+ checkpoint_path = hydra.utils.to_absolute_path(config.load_checkpoint)
+ else:
+ checkpoint_path = None
+
+ model = build_model_from_config(config, pm.hook)
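+ # `load_model_weight` only initializes the model parameters from a checkpoint's state_dict;
+ # `load_checkpoint` above is instead passed to `trainer.fit` as `ckpt_path` and resumes the
+ # full training state (optimizer, schedulers, epoch counters).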
+ if config.load_model_weight:
+ model_weight_path = hydra.utils.to_absolute_path(config.load_model_weight)
+ ckpt_weight = torch.load(model_weight_path, map_location=torch.device('cpu'))["state_dict"]
+ model.load_state_dict(ckpt_weight, strict=False)
+ else:
+ model_weight_path = None
+ callbacks = hydra_zen.instantiate(config.trainer.callbacks, _convert_="all")
+ callbacks = callbacks if callbacks else []
+ if config.trainer.logger is not False:
+ lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
+ callbacks.append(lr_monitor)
+
+ trainer: pl.Trainer = hydra_zen.instantiate(config.trainer, callbacks=callbacks, _convert_="all")
+
+ trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/ocl/combined_model.py b/ocl/combined_model.py
new file mode 100644
index 0000000..365a722
--- /dev/null
+++ b/ocl/combined_model.py
@@ -0,0 +1,177 @@
+"""Implementation of combined model."""
+from __future__ import annotations
+
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, Union
+
+import pytorch_lightning as pl
+import torch
+from torch import nn
+
+from ocl import base, path_defaults
+from ocl.utils.routing import Combined
+from ocl.utils.trees import walk_tree_with_paths
+from ocl.visualization_types import Visualization
+import os
+if TYPE_CHECKING:
+ import torchmetrics
+
+
+class CombinedModel(pl.LightningModule):
+ def __init__(
+ self,
+ models: Union[Dict[str, Any], nn.Module],
+ losses: Dict[str, Any],
+ visualizations: Dict[str, Any],
+ hooks: base.PluggyHookRelay,
+ training_metrics: Dict[str, torchmetrics.Metric] = None,
+ evaluation_metrics: Dict[str, torchmetrics.Metric] = None,
+ vis_log_frequency: int = 100,
+ ):
+ super().__init__()
+ if isinstance(models, Dict):
+ models = Combined(models)
+ self.models = models
+ self.losses = losses
+ self.visualizations = visualizations
+ self.hooks = hooks
+ self.vis_log_frequency = vis_log_frequency
+ self.return_outputs_on_validation = False
+
+ if training_metrics is None:
+ training_metrics = {}
+ self.training_metrics = torch.nn.ModuleDict(training_metrics)
+
+ if evaluation_metrics is None:
+ evaluation_metrics = {}
+ self.evaluation_metrics = torch.nn.ModuleDict(evaluation_metrics)
+
+ def configure_optimizers(self):
+ return self.hooks.configure_optimizers(model=self)
+
+ def __getattribute__(self, name):
+ """Forward pytorch lightning module hooks to the plugin manager.
+
+ We need to implement `__getattribute__` as the model hooks are defined in a superclass of
+ `pl.LightningModule` and thus `__getattr__` would never get called for them. This makes the
+ call a bit more clumsy.
+ """
+ if not name.startswith("__") and hasattr(pl.core.hooks.ModelHooks, name):
+ # A PyTorch Lightning hook is being called.
+ try:
+ hook_caller = getattr(self.hooks, name)
+ return partial(hook_caller, model=self)
+ except AttributeError:
+ pass
+ return super().__getattribute__(name)
+
+ def forward(self, input_data: dict):
+ # Maybe we should use something like a read only dict to prevent existing keys from being
+ # overwritten.
+ data: Dict[str, Any]
+ data = {
+ path_defaults.INPUT: input_data,
+ path_defaults.GLOBAL_STEP: self.global_step,
+ path_defaults.MODEL: self,
+ }
+ return self.models(inputs=data)
+
+ def _compute_losses(self, inputs, phase="train"):
+ quantities_to_log = {}
+ # We write additional loss outputs directly into the inputs dict, and thus do not need to
+ # return them.
+ outputs = inputs["losses"] = {}
+ for name, loss in self.losses.items():
+ out = loss(inputs=inputs)
+ if isinstance(out, tuple):
+ # Additional outputs that should be logged for later access.
+ # Some visualizations require having access to loss quantities, thus we need to save
+ # them for later here.
+ out, additional_outputs = out
+ outputs[name] = additional_outputs
+ quantities_to_log[f"{phase}/{name}"] = out
+
+ losses = []
+ for loss in quantities_to_log.values():
+ losses.append(loss)
+
+ total_loss = torch.stack(losses).sum()
+
+ # Log total loss only if there is more than one task
+ if len(losses) > 1:
+ quantities_to_log[f"{phase}/loss_total"] = total_loss
+
+ return total_loss, quantities_to_log
+
+ def predict_step(self, batch, batch_idx):
+ outputs = self(batch)
+ # Remove things not needed in prediction output.
+ del outputs[path_defaults.MODEL], outputs[path_defaults.GLOBAL_STEP]
+ return outputs
+
+ def training_step(self, batch, batch_idx):
+ batch_size = batch["batch_size"]
+ outputs = self(batch)
+ total_loss, quantities_to_log = self._compute_losses(outputs)
+
+ quantities_to_log.update(self._compute_metrics(outputs, self.training_metrics))
+ self.log_dict(quantities_to_log, on_step=True, on_epoch=False, batch_size=batch_size)
+
+ if self.trainer.global_step % self.vis_log_frequency == 0:
+ self._log_visualizations(outputs)
+
+ return total_loss
+
+ def validation_step(self, batch, batch_idx):
+ batch_size = batch["batch_size"]
+ outputs = self(batch)
+ total_loss, quantities_to_log = self._compute_losses(outputs, phase="val")
+
+ quantities_to_log.update(
+ self._compute_metrics(outputs, self.evaluation_metrics, phase="val")
+ )
+ self.log_dict(
+ quantities_to_log, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size
+ )
+
+ if batch_idx == 0:
+ self._log_visualizations(outputs, phase="val")
+
+ if self.return_outputs_on_validation:
+ return outputs # Used for saving model outputs during eval
+ else:
+ return None
+
+ def _compute_metrics(self, outputs, metric_fns, phase="train"):
+ metrics = {}
+ if len(metric_fns) > 0:
+ for metric_name, metric in metric_fns.items():
+ if phase == "val":
+ # Call update instead of forward to avoid unnecessary metric compute on batch.
+ metric.update(**outputs)
+ else:
+ metric(**outputs)
+ metrics[f"{phase}/{metric_name}"] = metric
+
+ return metrics
+
+ def _log_visualizations(self, outputs, phase="train"):
+ if self.logger is None:
+ return
+ logger_experiment = self.logger.experiment
+ visualizations = {}
+ for name, vis in self.visualizations.items():
+ visualizations[name] = vis(inputs=outputs)
+
+ visualization_iterator = walk_tree_with_paths(
+ visualizations, path=None, instance_check=lambda a: isinstance(a, Visualization)
+ )
+ for path, vis in visualization_iterator:
+ str_path = ".".join(path)
+ vis.add_to_experiment(
+ experiment=logger_experiment,
+ tag=f"{phase}/{str_path}",
+ global_step=self.trainer.global_step,
+ )
diff --git a/ocl/conditioning.py b/ocl/conditioning.py
new file mode 100644
index 0000000..d69271b
--- /dev/null
+++ b/ocl/conditioning.py
@@ -0,0 +1,288 @@
+"""Implementation of conditioning approaches for slots."""
+from typing import Callable, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+
+from ocl import base, path_defaults
+from ocl.utils.routing import RoutableMixin
+
+
+class RandomConditioning(base.Conditioning, RoutableMixin):
+ """Random conditioning with potentially learnt mean and stddev."""
+
+ def __init__(
+ self,
+ object_dim: int,
+ n_slots: int,
+ learn_mean: bool = True,
+ learn_std: bool = True,
+ mean_init: Optional[Callable[[torch.Tensor], None]] = None,
+ logsigma_init: Optional[Callable[[torch.Tensor], None]] = None,
+ batch_size_path: Optional[str] = path_defaults.BATCH_SIZE,
+ ):
+ base.Conditioning.__init__(self)
+ RoutableMixin.__init__(self, {"batch_size": batch_size_path})
+ self.n_slots = n_slots
+ self.object_dim = object_dim
+
+ if learn_mean:
+ self.slots_mu = nn.Parameter(torch.zeros(1, 1, object_dim))
+ else:
+ self.register_buffer("slots_mu", torch.zeros(1, 1, object_dim))
+
+ if learn_std:
+ self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, object_dim))
+ else:
+ self.register_buffer("slots_logsigma", torch.zeros(1, 1, object_dim))
+
+ if mean_init is None:
+ mean_init = nn.init.xavier_uniform_
+ if logsigma_init is None:
+ logsigma_init = nn.init.xavier_uniform_
+
+ with torch.no_grad():
+ mean_init(self.slots_mu)
+ logsigma_init(self.slots_logsigma)
+
+ @RoutableMixin.route
+ def forward(self, batch_size: int) -> base.ConditioningOutput:
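+ # Sample the initial slots via the reparameterization trick: slots = mu + sigma * eps with
+ # eps ~ N(0, I), giving a tensor of shape (batch_size, n_slots, object_dim).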
+ mu = self.slots_mu.expand(batch_size, self.n_slots, -1)
+ sigma = self.slots_logsigma.exp().expand(batch_size, self.n_slots, -1)
+ return mu + sigma * torch.randn_like(mu)
+
+
+class LearntConditioning(base.Conditioning, RoutableMixin):
+ """Conditioning with a learnt set of slot initializations, similar to DETR."""
+
+ def __init__(
+ self,
+ object_dim: int,
+ n_slots: int,
+ slot_init: Optional[Callable[[torch.Tensor], None]] = None,
+ batch_size_path: Optional[str] = path_defaults.BATCH_SIZE,
+ ):
+ base.Conditioning.__init__(self)
+ RoutableMixin.__init__(self, {"batch_size": batch_size_path})
+ self.n_slots = n_slots
+ self.object_dim = object_dim
+
+ self.slots = nn.Parameter(torch.zeros(1, n_slots, object_dim))
+
+ if slot_init is None:
+ slot_init = nn.init.normal_
+
+ with torch.no_grad():
+ slot_init(self.slots)
+
+ @RoutableMixin.route
+ def forward(self, batch_size: int) -> base.ConditioningOutput:
+ return self.slots.expand(batch_size, -1, -1)
+
+
+class RandomConditioningWithQMCSampling(RandomConditioning):
+ """Random conditioning with learnt mean and stddev using Quasi-Monte Carlo (QMC) samples."""
+
+ def __init__(
+ self,
+ object_dim: int,
+ n_slots: int,
+ learn_mean: bool = True,
+ learn_std: bool = True,
+ mean_init: Optional[Callable[[torch.Tensor], None]] = None,
+ logsigma_init: Optional[Callable[[torch.Tensor], None]] = None,
+ batch_size_path: Optional[str] = path_defaults.BATCH_SIZE,
+ ):
+ super().__init__(
+ object_dim,
+ n_slots,
+ learn_mean,
+ learn_std,
+ mean_init,
+ logsigma_init,
+ batch_size_path=batch_size_path,
+ )
+
+ import scipy.stats # Import lazily because scipy takes some time to import
+
+ self.randn_rng = scipy.stats.qmc.MultivariateNormalQMC(mean=np.zeros(object_dim))
+
+ def _randn(self, *args: Tuple[int]) -> torch.Tensor:
+ n_elements = np.prod(args)
+ # QMC sampler needs to sample powers of 2 numbers at a time
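+ # (e.g. batch_size=4 and n_slots=7 give 28 elements, so 32 QMC samples are drawn and truncated).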
+ n_elements_rounded2 = 2 ** int(np.ceil(np.log2(n_elements)))
+ z = self.randn_rng.random(n_elements_rounded2)[:n_elements]
+
+ return torch.from_numpy(z).view(*args, -1)
+
+ @RoutableMixin.route
+ def forward(self, batch_size: int) -> base.ConditioningOutput:
+ mu = self.slots_mu.expand(batch_size, self.n_slots, -1)
+ sigma = self.slots_logsigma.exp().expand(batch_size, self.n_slots, -1)
+
+ z = self._randn(batch_size, self.n_slots).to(mu, non_blocking=True)
+ return mu + sigma * z
+
+
+class SlotwiseLearntConditioning(base.Conditioning, RoutableMixin):
+ """Random conditioning with learnt mean and stddev for each slot.
+
+ Removes permutation equivariance compared to the original slot attention conditioning.
+ """
+
+ def __init__(
+ self,
+ object_dim: int,
+ n_slots: int,
+ mean_init: Optional[Callable[[torch.Tensor], None]] = None,
+ logsigma_init: Optional[Callable[[torch.Tensor], None]] = None,
+ batch_size_path: Optional[str] = path_defaults.BATCH_SIZE,
+ ):
+ base.Conditioning.__init__(self)
+ RoutableMixin.__init__(self, {"batch_size": batch_size_path})
+ self.n_slots = n_slots
+ self.object_dim = object_dim
+
+ self.slots_mu = nn.Parameter(torch.zeros(1, n_slots, object_dim))
+ self.slots_logsigma = nn.Parameter(torch.zeros(1, n_slots, object_dim))
+
+ if mean_init is None:
+ mean_init = nn.init.normal_
+ if logsigma_init is None:
+ logsigma_init = nn.init.xavier_uniform_
+
+ with torch.no_grad():
+ mean_init(self.slots_mu)
+ logsigma_init(self.slots_logsigma)
+
+ @RoutableMixin.route
+ def forward(self, batch_size: int) -> base.ConditioningOutput:
+ mu = self.slots_mu.expand(batch_size, -1, -1)
+ sigma = self.slots_logsigma.exp().expand(batch_size, -1, -1)
+ return mu + sigma * torch.randn_like(mu)
+
+
+# NOTE: This is required to load a pre-trained SAVi model. It can be removed once SAVi is retrained.
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_size: int,  # FIXME: added because submodules cannot be instantiated otherwise
+ hidden_size: int,
+ output_size: int, # if not given, should be inputs.shape[-1] at forward
+ num_hidden_layers: int = 1,
+ activation_fn: nn.Module = nn.ReLU,
+ layernorm: Optional[str] = None,
+ activate_output: bool = False,
+ residual: bool = False,
+ weight_init=None,
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.output_size = output_size
+ self.num_hidden_layers = num_hidden_layers
+ self.activation_fn = activation_fn
+ self.layernorm = layernorm
+ self.activate_output = activate_output
+ self.residual = residual
+ self.weight_init = weight_init
+ if self.layernorm == "pre":
+ self.layernorm_module = nn.LayerNorm(input_size, eps=1e-6)
+ elif self.layernorm == "post":
+ self.layernorm_module = nn.LayerNorm(output_size, eps=1e-6)
+ # mlp
+ self.model = nn.ModuleList()
+ self.model.add_module("dense_mlp_0", nn.Linear(self.input_size, self.hidden_size))
+ self.model.add_module("dense_mlp_0_act", self.activation_fn())
+ for i in range(1, self.num_hidden_layers):
+ self.model.add_module(f"den_mlp_{i}", nn.Linear(self.hidden_size, self.hidden_size))
+ self.model.add_module(f"dense_mlp_{i}_act", self.activation_fn())
+ self.model.add_module(
+ f"dense_mlp_{self.num_hidden_layers}", nn.Linear(self.hidden_size, self.output_size)
+ )
+ if self.activate_output:
+ self.model.add_module(f"dense_mlp_{self.num_hidden_layers}_act", self.activation_fn())
+ for name, module in self.model.named_children():
+ if "act" not in name:
+ nn.init.xavier_uniform_(module.weight)
+
+ def forward(self, inputs: torch.Tensor, train: bool = False) -> torch.Tensor:
+ del train # Unused
+
+ x = inputs
+ if self.layernorm == "pre":
+ x = self.layernorm_module(x)
+ for layer in self.model:
+ x = layer(x)
+ if self.residual:
+ x = x + inputs
+ if self.layernorm == "post":
+ x = self.layernorm_module(x)
+ return x
+
+
+class CoordinateEncoderStateInit(base.Conditioning, RoutableMixin):
+ """State init that encodes bounding box corrdinates as conditional input.
+
+ Attributes:
+ embedding_transform: A nn.Module that is applied on inputs (bounding boxes).
+ prepend_background: Boolean flag; whether to prepend a special, zero-valued
+ background bounding box to the input. Default: True.
+ center_of_mass: Boolean flag; whether to convert bounding boxes to center
+ of mass coordinates. Default: False.
+ background_value: Default value to fill in the background.
+ """
+
+ def __init__(
+ self,
+ object_dim: int,
+ prepend_background: bool = True,
+ center_of_mass: bool = False,
+ background_value: float = 0.0,
+ batch_size_path: Optional[str] = path_defaults.BATCH_SIZE,
+ ):
+ base.Conditioning.__init__(self)
+ RoutableMixin.__init__(self, {"batch_size": batch_size_path})
+
+ # self.embedding_transform = torchvision.ops.MLP(4, [256, 128], norm_layer=None)
+ self.embedding_transform = MLP(
+ input_size=4, hidden_size=256, output_size=128, layernorm=None
+ )
+ self.prepend_background = prepend_background
+ self.center_of_mass = center_of_mass
+ self.background_value = background_value
+ self.object_dim = object_dim
+
+ @RoutableMixin.route
+ def forward(self, target_bbox: torch.Tensor, batch_size: int) -> base.ConditioningOutput:
+ del batch_size # Unused.
+
+ # inputs.shape = (batch_size, seq_len, bboxes, 4)
+ inputs = target_bbox[:, 0] # Only condition on first time step.
+ # inputs.shape = (batch_size, bboxes, 4)
+ if self.prepend_background:
+ # Adds a fake background box [0, 0, 0, 0] at the beginning.
+ batch_size = inputs.shape[0]
+
+ # Encode the background as specified by the background_value.
+ background = torch.full(
+ (batch_size, 1, 4),
+ self.background_value,
+ dtype=inputs.dtype,
+ device=inputs.device,
+ )
+
+ inputs = torch.cat([background, inputs], dim=1)
+ # inputs = torch.cat([inputs, background], dim=1)
+
+ if self.center_of_mass:
+ y_pos = (inputs[:, :, 0] + inputs[:, :, 2]) / 2
+ x_pos = (inputs[:, :, 1] + inputs[:, :, 3]) / 2
+ inputs = torch.stack([y_pos, x_pos], dim=-1)
+
+ slots = self.embedding_transform(inputs)
+ # duplicated_slots = torch.cat([slots, slots], dim=1)
+
+ return slots
diff --git a/ocl/config/__init__.py b/ocl/config/__init__.py
new file mode 100644
index 0000000..3f1692a
--- /dev/null
+++ b/ocl/config/__init__.py
@@ -0,0 +1,39 @@
+from hydra.core.config_store import ConfigStore
+from omegaconf import OmegaConf
+
+from ocl.config import (
+ conditioning,
+ datasets,
+ feature_extractors,
+ metrics,
+ neural_networks,
+ optimizers,
+ perceptual_groupings,
+ plugins,
+ predictor,
+ utils,
+)
+
+config_store = ConfigStore.instance()
+
+conditioning.register_configs(config_store)
+
+datasets.register_configs(config_store)
+datasets.register_resolvers(OmegaConf)
+
+feature_extractors.register_configs(config_store)
+
+metrics.register_configs(config_store)
+
+neural_networks.register_configs(config_store)
+
+optimizers.register_configs(config_store)
+
+perceptual_groupings.register_configs(config_store)
+predictor.register_configs(config_store)
+
+plugins.register_configs(config_store)
+plugins.register_resolvers(OmegaConf)
+
+utils.register_configs(config_store)
+utils.register_resolvers(OmegaConf)
diff --git a/ocl/config/conditioning.py b/ocl/config/conditioning.py
new file mode 100644
index 0000000..41c1f9e
--- /dev/null
+++ b/ocl/config/conditioning.py
@@ -0,0 +1,58 @@
+"""Configuration of slot conditionings."""
+import dataclasses
+
+from hydra_zen import builds
+from omegaconf import SI
+
+from ocl import conditioning
+
+
+@dataclasses.dataclass
+class ConditioningConfig:
+ """Base class for conditioning module configuration."""
+
+
+# Unfortunately, we cannot define object_dim as part of the base config class, as this prevents the
+# use of required positional arguments in the subclasses. We thus pass it to each config here.
+LearntConditioningConfig = builds(
+ conditioning.LearntConditioning,
+ object_dim=SI("${perceptual_grouping.object_dim}"),
+ builds_bases=(ConditioningConfig,),
+ populate_full_signature=True,
+)
+
+RandomConditioningConfig = builds(
+ conditioning.RandomConditioning,
+ object_dim=SI("${perceptual_grouping.object_dim}"),
+ builds_bases=(ConditioningConfig,),
+ populate_full_signature=True,
+)
+
+RandomConditioningWithQMCSamplingConfig = builds(
+ conditioning.RandomConditioningWithQMCSampling,
+ object_dim=SI("${perceptual_grouping.object_dim}"),
+ builds_bases=(ConditioningConfig,),
+ populate_full_signature=True,
+)
+
+SlotwiseLearntConditioningConfig = builds(
+ conditioning.SlotwiseLearntConditioning,
+ object_dim=SI("${perceptual_grouping.object_dim}"),
+ builds_bases=(ConditioningConfig,),
+ populate_full_signature=True,
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="schemas", name="conditioning", node=ConditioningConfig)
+
+ config_store.store(group="conditioning", name="learnt", node=LearntConditioningConfig)
+ config_store.store(group="conditioning", name="random", node=RandomConditioningConfig)
+ config_store.store(
+ group="conditioning",
+ name="random_with_qmc_sampling",
+ node=RandomConditioningWithQMCSamplingConfig,
+ )
+ config_store.store(
+ group="conditioning", name="slotwise_learnt_random", node=SlotwiseLearntConditioningConfig
+ )
diff --git a/ocl/config/datasets.py b/ocl/config/datasets.py
new file mode 100644
index 0000000..0e19553
--- /dev/null
+++ b/ocl/config/datasets.py
@@ -0,0 +1,82 @@
+"""Register all dataset related configs."""
+import dataclasses
+import os
+
+import yaml
+from hydra.utils import to_absolute_path
+from hydra_zen import builds
+
+from ocl import datasets
+
+
+def get_region():
+ """Determine the region this EC2 instance is running in.
+
+ Returns None if not running on an EC2 instance.
+ """
+ import requests
+
+ try:
+ r = requests.get(
+ "http://169.254.169.254/latest/dynamic/instance-identity/document", timeout=0.5
+ )
+ response_json = r.json()
+ return response_json.get("region")
+ except Exception:
+ # Not running on an ec2 instance.
+ return None
+
+
+# Determine the region name and select the bucket accordingly.
+AWS_REGION = get_region()
+if AWS_REGION in ["us-east-2", "us-west-2", "eu-west-1"]:
+ # Select bucket in same region.
+ DEFAULT_S3_PATH = f"s3://object-centric-datasets-{AWS_REGION}"
+ # Datasets can also be copied locally from these buckets, e.g.:
+ #   aws s3 cp --recursive s3://object-centric-datasets-us-west-2/movi_e ./movi_e
+else:
+ # Use MRAP to find closest bucket.
+ DEFAULT_S3_PATH = "s3://arn:aws:s3::436622332146:accesspoint/m6p4hmmybeu97.mrap"
+
+
+@dataclasses.dataclass
+class DataModuleConfig:
+ """Base class for PyTorch Lightning DataModules.
+
+ This class does not actually do anything but ensures that datasets behave like pytorch lightning
+ datamodules.
+ """
+
+
+def dataset_prefix(path):
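+ """Resolve a dataset path to a webdataset URI.
+
+ If the DATASET_PREFIX environment variable is set, the path is resolved relative to it
+ (e.g., with the hypothetical setting DATASET_PREFIX=/data/webdatasets, "movi_e/train" becomes
+ "/data/webdatasets/movi_e/train"); otherwise the shards are streamed from S3 via `aws s3 cp`.
+ """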
+ prefix = os.environ.get("DATASET_PREFIX")
+ if prefix:
+ return f"{prefix}/{path}"
+ # Use the path to the multi-region bucket if no override is specified.
+ return f"pipe:aws s3 cp --quiet {DEFAULT_S3_PATH}/{path} -"
+
+
+def read_yaml(path):
+ with open(to_absolute_path(path), "r") as f:
+ return yaml.safe_load(f)
+
+
+WebdatasetDataModuleConfig = builds(
+ datasets.WebdatasetDataModule, populate_full_signature=True, builds_bases=(DataModuleConfig,)
+)
+DummyDataModuleConfig = builds(
+ datasets.DummyDataModule, populate_full_signature=True, builds_bases=(DataModuleConfig,)
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="schemas", name="dataset", node=DataModuleConfig)
+ config_store.store(group="dataset", name="webdataset", node=WebdatasetDataModuleConfig)
+ config_store.store(group="dataset", name="dummy_dataset", node=DummyDataModuleConfig)
+
+
+def register_resolvers(omegaconf):
+ omegaconf.register_new_resolver("dataset_prefix", dataset_prefix)
+ omegaconf.register_new_resolver("read_yaml", read_yaml)
diff --git a/ocl/config/feature_extractors.py b/ocl/config/feature_extractors.py
new file mode 100644
index 0000000..415f2cb
--- /dev/null
+++ b/ocl/config/feature_extractors.py
@@ -0,0 +1,63 @@
+"""Configurations for feature extractors."""
+import dataclasses
+
+from hydra_zen import make_custom_builds_fn
+
+from ocl import feature_extractors
+
+
+@dataclasses.dataclass
+class FeatureExtractorConfig:
+ """Base class for PyTorch Lightning DataModules.
+
+ This class does not actually do anything but ensures that feature extractors give outputs of
+ a defined structure.
+ """
+
+ pass
+
+
+builds_feature_extractor = make_custom_builds_fn(
+ populate_full_signature=True,
+)
+
+TimmFeatureExtractorConfig = builds_feature_extractor(
+ feature_extractors.TimmFeatureExtractor,
+ builds_bases=(FeatureExtractorConfig,),
+)
+SlotAttentionFeatureExtractorConfig = builds_feature_extractor(
+ feature_extractors.SlotAttentionFeatureExtractor,
+ builds_bases=(FeatureExtractorConfig,),
+)
+DVAEFeatureExtractorConfig = builds_feature_extractor(
+ feature_extractors.DVAEFeatureExtractor,
+ builds_bases=(FeatureExtractorConfig,),
+)
+SAViFeatureExtractorConfig = builds_feature_extractor(
+ feature_extractors.SAViFeatureExtractor,
+ builds_bases=(FeatureExtractorConfig,),
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="schemas", name="feature_extractor", node=FeatureExtractorConfig)
+ config_store.store(
+ group="feature_extractor",
+ name="timm_model",
+ node=TimmFeatureExtractorConfig,
+ )
+ config_store.store(
+ group="feature_extractor",
+ name="slot_attention",
+ node=SlotAttentionFeatureExtractorConfig,
+ )
+ config_store.store(
+ group="feature_extractor",
+ name="savi",
+ node=SAViFeatureExtractorConfig,
+ )
+ config_store.store(
+ group="feature_extractor",
+ name="dvae",
+ node=DVAEFeatureExtractorConfig,
+ )
diff --git a/ocl/config/metrics.py b/ocl/config/metrics.py
new file mode 100644
index 0000000..c41d98a
--- /dev/null
+++ b/ocl/config/metrics.py
@@ -0,0 +1,157 @@
+"""Register metric related configs."""
+import dataclasses
+
+from hydra_zen import builds, make_custom_builds_fn
+
+from ocl import metrics
+
+
+@dataclasses.dataclass
+class MetricConfig:
+ """Base class for metrics."""
+
+ pass
+
+
+builds_metric = make_custom_builds_fn(
+ populate_full_signature=True,
+)
+
+TensorStatisticConfig = builds_metric(metrics.TensorStatistic, builds_bases=(MetricConfig,))
+
+
+TorchmetricsWrapperConfig = builds_metric(metrics.TorchmetricsWrapper, builds_bases=(MetricConfig,))
+
+
+PurityMetricConfig = builds_metric(
+ metrics.MutualInfoAndPairCounting,
+ metric_name="purity",
+ builds_bases=(MetricConfig,),
+)
+PrecisionMetricConfig = builds_metric(
+ metrics.MutualInfoAndPairCounting,
+ metric_name="precision",
+ builds_bases=(MetricConfig,),
+)
+RecallMetricConfig = builds_metric(
+ metrics.MutualInfoAndPairCounting,
+ metric_name="recall",
+ builds_bases=(MetricConfig,),
+)
+F1MetricConfig = builds_metric(
+ metrics.MutualInfoAndPairCounting,
+ metric_name="f1",
+ builds_bases=(MetricConfig,),
+)
+AMIMetricConfig = builds_metric(
+ metrics.MutualInfoAndPairCounting,
+ metric_name="ami",
+ builds_bases=(MetricConfig,),
+)
+NMIMetricConfig = builds_metric(
+ metrics.MutualInfoAndPairCounting,
+ metric_name="nmi",
+ builds_bases=(MetricConfig,),
+)
+ARISklearnMetricConfig = builds_metric(
+ metrics.MutualInfoAndPairCounting,
+ metric_name="ari_sklearn",
+ builds_bases=(MetricConfig,),
+)
+
+ARIMetricConfig = builds_metric(metrics.ARIMetric, builds_bases=(MetricConfig,))
+PatchARIMetricConfig = builds_metric(
+ metrics.PatchARIMetric,
+ builds_bases=(MetricConfig,),
+)
+UnsupervisedMaskIoUMetricConfig = builds_metric(
+ metrics.UnsupervisedMaskIoUMetric,
+ builds_bases=(MetricConfig,),
+)
+MOTMetricConfig = builds_metric(
+ metrics.MOTMetric,
+ builds_bases=(MetricConfig,),
+)
+MaskCorLocMetricConfig = builds_metric(
+ metrics.UnsupervisedMaskIoUMetric,
+ matching="best_overlap",
+ correct_localization=True,
+ builds_bases=(MetricConfig,),
+)
+AverageBestOverlapMetricConfig = builds_metric(
+ metrics.UnsupervisedMaskIoUMetric,
+ matching="best_overlap",
+ builds_bases=(MetricConfig,),
+)
+BestOverlapObjectRecoveryMetricConfig = builds_metric(
+ metrics.UnsupervisedMaskIoUMetric,
+ matching="best_overlap",
+ compute_discovery_fraction=True,
+ builds_bases=(MetricConfig,),
+)
+UnsupervisedBboxIoUMetricConfig = builds_metric(
+ metrics.UnsupervisedBboxIoUMetric,
+ builds_bases=(MetricConfig,),
+)
+BboxCorLocMetricConfig = builds_metric(
+ metrics.UnsupervisedBboxIoUMetric,
+ matching="best_overlap",
+ correct_localization=True,
+ builds_bases=(MetricConfig,),
+)
+BboxRecallMetricConfig = builds_metric(
+ metrics.UnsupervisedBboxIoUMetric,
+ matching="best_overlap",
+ compute_discovery_fraction=True,
+ builds_bases=(MetricConfig,),
+)
+
+
+DatasetSemanticMaskIoUMetricConfig = builds_metric(metrics.DatasetSemanticMaskIoUMetric)
+
+SklearnClusteringConfig = builds(
+ metrics.SklearnClustering,
+ populate_full_signature=True,
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="metrics", name="tensor_statistic", node=TensorStatisticConfig)
+
+ config_store.store(group="metrics", name="torchmetric", node=TorchmetricsWrapperConfig)
+ config_store.store(group="metrics", name="ami_metric", node=AMIMetricConfig)
+ config_store.store(group="metrics", name="nmi_metric", node=NMIMetricConfig)
+ config_store.store(group="metrics", name="ari_sklearn_metric", node=ARISklearnMetricConfig)
+ config_store.store(group="metrics", name="purity_metric", node=PurityMetricConfig)
+ config_store.store(group="metrics", name="precision_metric", node=PrecisionMetricConfig)
+ config_store.store(group="metrics", name="recall_metric", node=RecallMetricConfig)
+ config_store.store(group="metrics", name="f1_metric", node=F1MetricConfig)
+
+ config_store.store(group="metrics", name="mot_metric", node=MOTMetricConfig)
+ config_store.store(group="metrics", name="ari_metric", node=ARIMetricConfig)
+ config_store.store(group="metrics", name="patch_ari_metric", node=PatchARIMetricConfig)
+ config_store.store(
+ group="metrics", name="unsupervised_mask_iou_metric", node=UnsupervisedMaskIoUMetricConfig
+ )
+ config_store.store(group="metrics", name="mask_corloc_metric", node=MaskCorLocMetricConfig)
+ config_store.store(
+ group="metrics", name="average_best_overlap_metric", node=AverageBestOverlapMetricConfig
+ )
+ config_store.store(
+ group="metrics",
+ name="best_overlap_object_recovery_metric",
+ node=BestOverlapObjectRecoveryMetricConfig,
+ )
+ config_store.store(
+ group="metrics", name="unsupervised_bbox_iou_metric", node=UnsupervisedBboxIoUMetricConfig
+ )
+ config_store.store(group="metrics", name="bbox_corloc_metric", node=BboxCorLocMetricConfig)
+ config_store.store(group="metrics", name="bbox_recall_metric", node=BboxRecallMetricConfig)
+
+ config_store.store(
+ group="metrics",
+ name="dataset_semantic_mask_iou",
+ node=DatasetSemanticMaskIoUMetricConfig,
+ )
+ config_store.store(
+ group="clustering",
+ name="sklearn_clustering",
+ node=SklearnClusteringConfig,
+ )
diff --git a/ocl/config/neural_networks.py b/ocl/config/neural_networks.py
new file mode 100644
index 0000000..dde8771
--- /dev/null
+++ b/ocl/config/neural_networks.py
@@ -0,0 +1,36 @@
+"""Configs for neural networks."""
+import omegaconf
+from hydra_zen import builds
+
+from ocl import neural_networks
+
+MLPBuilderConfig = builds(
+ neural_networks.build_mlp,
+ features=omegaconf.MISSING,
+ zen_partial=True,
+ populate_full_signature=True,
+)
+TransformerEncoderBuilderConfig = builds(
+ neural_networks.build_transformer_encoder,
+ n_layers=omegaconf.MISSING,
+ n_heads=omegaconf.MISSING,
+ zen_partial=True,
+ populate_full_signature=True,
+)
+TransformerDecoderBuilderConfig = builds(
+ neural_networks.build_transformer_decoder,
+ n_layers=omegaconf.MISSING,
+ n_heads=omegaconf.MISSING,
+ zen_partial=True,
+ populate_full_signature=True,
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="neural_networks", name="mlp", node=MLPBuilderConfig)
+ config_store.store(
+ group="neural_networks", name="transformer_encoder", node=TransformerEncoderBuilderConfig
+ )
+ config_store.store(
+ group="neural_networks", name="transformer_decoder", node=TransformerDecoderBuilderConfig
+ )
diff --git a/ocl/config/optimizers.py b/ocl/config/optimizers.py
new file mode 100644
index 0000000..76ac6b7
--- /dev/null
+++ b/ocl/config/optimizers.py
@@ -0,0 +1,38 @@
+"""Pytorch optimizers."""
+import dataclasses
+
+import torch.optim
+from hydra_zen import make_custom_builds_fn
+
+
+@dataclasses.dataclass
+class OptimizerConfig:
+ pass
+
+
+# TODO(hornmax): We cannot automatically extract type information from the torch SGD implementation,
+# thus we define it manually here.
+@dataclasses.dataclass
+class SGDConfig(OptimizerConfig):
+ learning_rate: float
+ momentum: float = 0.0
+ dampening: float = 0.0
+ nesterov: bool = False
+ _target_: str = "hydra_zen.funcs.zen_processing"
+ _zen_target: str = "torch.optim.SGD"
+ _zen_partial: bool = True
+
+
+pbuilds = make_custom_builds_fn(
+ zen_partial=True,
+ populate_full_signature=True,
+)
+
+AdamConfig = pbuilds(torch.optim.Adam, builds_bases=(OptimizerConfig,))
+AdamWConfig = pbuilds(torch.optim.AdamW, builds_bases=(OptimizerConfig,))
+
+
+def register_configs(config_store):
+ config_store.store(group="optimizers", name="sgd", node=SGDConfig)
+ config_store.store(group="optimizers", name="adam", node=AdamConfig)
+ config_store.store(group="optimizers", name="adamw", node=AdamWConfig)
diff --git a/ocl/config/perceptual_groupings.py b/ocl/config/perceptual_groupings.py
new file mode 100644
index 0000000..9a0a7ae
--- /dev/null
+++ b/ocl/config/perceptual_groupings.py
@@ -0,0 +1,29 @@
+"""Perceptual grouping models."""
+import dataclasses
+
+from hydra_zen import builds
+
+from ocl import perceptual_grouping
+
+
+@dataclasses.dataclass
+class PerceptualGroupingConfig:
+ """Configuration class of perceptual grouping models."""
+
+
+SlotAttentionConfig = builds(
+ perceptual_grouping.SlotAttentionGrouping,
+ builds_bases=(PerceptualGroupingConfig,),
+ populate_full_signature=True,
+)
+SlotAttentionGumbelV1Config = builds(
+ perceptual_grouping.SlotAttentionGroupingGumbelV1,
+ builds_bases=(PerceptualGroupingConfig,),
+ populate_full_signature=True,
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="schemas", name="perceptual_grouping", node=PerceptualGroupingConfig)
+ config_store.store(group="perceptual_grouping", name="slot_attention", node=SlotAttentionConfig)
+ config_store.store(group="perceptual_grouping", name="slot_attention_gumbel_v1", node=SlotAttentionGumbelV1Config)
diff --git a/ocl/config/plugins.py b/ocl/config/plugins.py
new file mode 100644
index 0000000..e8528f0
--- /dev/null
+++ b/ocl/config/plugins.py
@@ -0,0 +1,246 @@
+"""Configuration of plugins."""
+import dataclasses
+import functools
+
+import hydra_zen
+from hydra_zen import builds
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
+
+from ocl import plugins, scheduling
+from ocl.config.optimizers import OptimizerConfig
+
+
+@dataclasses.dataclass
+class PluginConfig:
+ """Base class for plugin configurations."""
+
+ pass
+
+
+@dataclasses.dataclass
+class LRSchedulerConfig:
+ pass
+
+
+def exponential_decay_with_optional_warmup(
+ optimizer, decay_rate: float = 1.0, decay_steps: int = 10000, warmup_steps: int = 0
+):
+ """Return pytorch lighting optimizer configuration for exponential decay with optional warmup.
+
+ Returns:
+ Dict with structure compatible with ptl. See
+ https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers
+ """
+ decay_fn = functools.partial(
+ scheduling.exp_decay_with_warmup_fn,
+ decay_rate=decay_rate,
+ decay_steps=decay_steps,
+ warmup_steps=warmup_steps,
+ )
+
+ return {"lr_scheduler": {"scheduler": LambdaLR(optimizer, decay_fn), "interval": "step"}}
+
+
+def exponential_decay_after_optional_warmup(
+ optimizer, decay_rate: float = 1.0, decay_steps: int = 10000, warmup_steps: int = 0
+):
+ """Return pytorch lighting optimizer configuration for exponential decay after optional warmup.
+
+ Returns:
+ Dict with structure compatible with ptl. See
+ https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers
+ """
+ decay_fn = functools.partial(
+ scheduling.exp_decay_after_warmup_fn,
+ decay_rate=decay_rate,
+ decay_steps=decay_steps,
+ warmup_steps=warmup_steps,
+ )
+
+ return {"lr_scheduler": {"scheduler": LambdaLR(optimizer, decay_fn), "interval": "step"}}
+
+
+def plateau_decay(optimizer, decay_rate: float = 1.0, patience: int = 10):
+ """Return pytorch lighting optimizer configuration for plato decay.
+
+ Returns:
+ Dict with structure compatible with ptl. See
+ https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers
+ """
+ plateau_scheduler = ReduceLROnPlateau(
+ optimizer=optimizer, mode="min", factor=decay_rate, patience=patience
+ )
+ return {
+ "lr_scheduler": {
+ "scheduler": plateau_scheduler,
+ "interval": "epoch",
+ "monitor": "val/loss_total",
+ }
+ }
+
+
+def cosine_annealing_with_optional_warmup(
+ optimizer,
+ T_max: int = 100000,
+ eta_min: float = 0.0,
+ warmup_steps: int = 0,
+ error_on_exceeding_steps: bool = True,
+):
+ """Return pytorch lighting optimizer configuration for cosine annealing with warmup.
+
+ Returns:
+ Dict with structure compatible with ptl. See
+ https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers
+ """
+ return {
+ "lr_scheduler": {
+ "scheduler": scheduling.CosineAnnealingWithWarmup(
+ optimizer,
+ T_max,
+ eta_min=eta_min,
+ warmup_steps=warmup_steps,
+ error_on_exceeding_steps=error_on_exceeding_steps,
+ ),
+ "interval": "step",
+ }
+ }
+
+
+PlateauDecayLR = builds(
+ plateau_decay,
+ zen_partial=True,
+ populate_full_signature=True,
+ builds_bases=(LRSchedulerConfig,),
+)
+
+ExpDecayLR = builds(
+ exponential_decay_with_optional_warmup,
+ zen_partial=True,
+ populate_full_signature=True,
+ builds_bases=(LRSchedulerConfig,),
+)
+
+ExpDecayAfterWarmupLR = builds(
+ exponential_decay_after_optional_warmup,
+ zen_partial=True,
+ populate_full_signature=True,
+ builds_bases=(LRSchedulerConfig,),
+)
+
+CosineAnnealingLR = builds(
+ cosine_annealing_with_optional_warmup,
+ zen_partial=True,
+ populate_full_signature=True,
+ builds_bases=(LRSchedulerConfig,),
+)
+
+
+@dataclasses.dataclass
+class HPSchedulerConfig:
+ """Base class for hyperparameter scheduler configuration."""
+
+
+LinearHPSchedulerConfig = builds(
+ scheduling.LinearHPScheduler,
+ builds_bases=(HPSchedulerConfig,),
+ populate_full_signature=True,
+)
+StepHPSchedulerConfig = builds(
+ scheduling.StepHPScheduler,
+ builds_bases=(HPSchedulerConfig,),
+ populate_full_signature=True,
+)
+CosineAnnealingHPSchedulerConfig = builds(
+ scheduling.CosineAnnealingHPScheduler,
+ builds_bases=(HPSchedulerConfig,),
+ populate_full_signature=True,
+)
+
+
+builds_plugin = hydra_zen.make_custom_builds_fn(
+ populate_full_signature=True,
+)
+OptimizationConfig = builds_plugin(
+ plugins.Optimization,
+ optimizer=OptimizerConfig,
+ lr_scheduler=LRSchedulerConfig,
+ builds_bases=(PluginConfig,),
+)
+SingleElementPreprocessingConfig = builds_plugin(
+ plugins.SingleElementPreprocessing, builds_bases=(PluginConfig,)
+)
+MultiElementPreprocessingConfig = builds_plugin(
+ plugins.MultiElementPreprocessing, builds_bases=(PluginConfig,)
+)
+DataPreprocessingConfig = builds_plugin(plugins.DataPreprocessing, builds_bases=(PluginConfig,))
+SubsetDatasetConfig = builds_plugin(plugins.SubsetDataset, builds_bases=(PluginConfig,))
+SampleFramesFromVideoConfig = builds_plugin(
+ plugins.SampleFramesFromVideo, builds_bases=(PluginConfig,)
+)
+SplitConsecutiveFramesConfig = builds_plugin(
+ plugins.SplitConsecutiveFrames, builds_bases=(PluginConfig,)
+)
+RandomStridedWindowConfig = builds_plugin(plugins.RandomStridedWindow, builds_bases=(PluginConfig,))
+RenameFieldsConfig = builds_plugin(plugins.RenameFields, builds_bases=(PluginConfig,))
+SpatialSlidingWindowConfig = builds_plugin(
+ plugins.SpatialSlidingWindow, builds_bases=(PluginConfig,)
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="schemas", name="lr_scheduler", node=LRSchedulerConfig)
+ config_store.store(group="lr_schedulers", name="exponential_decay", node=ExpDecayLR)
+ config_store.store(
+ group="lr_schedulers", name="exponential_decay_after_warmup", node=ExpDecayAfterWarmupLR
+ )
+ config_store.store(group="lr_schedulers", name="plateau_decay", node=PlateauDecayLR)
+ config_store.store(group="lr_schedulers", name="cosine_annealing", node=CosineAnnealingLR)
+
+ config_store.store(group="schemas", name="hp_scheduler", node=HPSchedulerConfig)
+ config_store.store(group="hp_schedulers", name="linear", node=LinearHPSchedulerConfig)
+ config_store.store(group="hp_schedulers", name="step", node=StepHPSchedulerConfig)
+ config_store.store(
+ group="hp_schedulers", name="cosine_annealing", node=CosineAnnealingHPSchedulerConfig
+ )
+
+ config_store.store(group="schemas", name="plugin", node=PluginConfig)
+ config_store.store(group="plugins", name="optimization", node=OptimizationConfig)
+ config_store.store(
+ group="plugins",
+ name="single_element_preprocessing",
+ node=SingleElementPreprocessingConfig,
+ )
+ config_store.store(
+ group="plugins",
+ name="multi_element_preprocessing",
+ node=MultiElementPreprocessingConfig,
+ )
+ config_store.store(
+ group="plugins",
+ name="data_preprocessing",
+ node=DataPreprocessingConfig,
+ )
+ config_store.store(group="plugins", name="subset_dataset", node=SubsetDatasetConfig)
+ config_store.store(
+ group="plugins", name="sample_frames_from_video", node=SampleFramesFromVideoConfig
+ )
+ config_store.store(
+ group="plugins", name="split_consecutive_frames", node=SplitConsecutiveFramesConfig
+ )
+ config_store.store(group="plugins", name="random_strided_window", node=RandomStridedWindowConfig)
+ config_store.store(group="plugins", name="rename_fields", node=RenameFieldsConfig)
+ config_store.store(
+ group="plugins", name="spatial_sliding_window", node=SpatialSlidingWindowConfig
+ )
+
+
+def _torchvision_interpolation_mode(mode):
+ import torchvision
+
+ return torchvision.transforms.InterpolationMode[mode.upper()]
+
+
+def register_resolvers(omegaconf):
+ omegaconf.register_new_resolver(
+ "torchvision_interpolation_mode", _torchvision_interpolation_mode
+ )
diff --git a/ocl/config/predictor.py b/ocl/config/predictor.py
new file mode 100644
index 0000000..ffae85c
--- /dev/null
+++ b/ocl/config/predictor.py
@@ -0,0 +1,23 @@
+"""Perceptual grouping models."""
+import dataclasses
+
+from hydra_zen import builds
+
+from ocl import predictor
+
+
+@dataclasses.dataclass
+class PredictorConfig:
+ """Configuration class of Predictor."""
+
+
+TransitionConfig = builds(
+ predictor.Predictor,
+ builds_bases=(PredictorConfig,),
+ populate_full_signature=True,
+)
+
+
+def register_configs(config_store):
+ config_store.store(group="schemas", name="predictor", node=PredictorConfig)
+ config_store.store(group="predictor", name="multihead_attention", node=TransitionConfig)
diff --git a/ocl/config/utils.py b/ocl/config/utils.py
new file mode 100644
index 0000000..be862f7
--- /dev/null
+++ b/ocl/config/utils.py
@@ -0,0 +1,108 @@
+"""Utility functions useful for configuration."""
+import ast
+from typing import Any, Callable
+
+from hydra_zen import builds
+
+from ocl.config.feature_extractors import FeatureExtractorConfig
+from ocl.config.perceptual_groupings import PerceptualGroupingConfig
+from ocl.distillation import EMASelfDistillation
+from ocl.utils.masking import CreateSlotMask
+from ocl.utils.routing import Combined, Recurrent
+import torch
+
+def lambda_string_to_function(function_string: str) -> Callable[..., Any]:
+ """Convert string of the form "lambda x: x" into a callable Python function."""
+ # This is a bit hacky but ensures that the syntax of the input is correct and contains
+ # a valid lambda function definition without requiring to run `eval`.
+ parsed = ast.parse(function_string)
+ is_lambda = isinstance(parsed.body[0], ast.Expr) and isinstance(parsed.body[0].value, ast.Lambda)
+ if not is_lambda:
+ raise ValueError(f"'{function_string}' is not a valid lambda definition.")
+
+ return eval(function_string)
+
+
+class ConfigDefinedLambda:
+ """Lambda function defined in the config.
+
+ This allows lambda functions defined in the config to be pickled.
+ """
+
+ def __init__(self, function_string: str):
+ self.__setstate__(function_string)
+
+ def __getstate__(self) -> str:
+ return self.function_string
+
+ def __setstate__(self, function_string: str):
+ self.function_string = function_string
+ self._fn = lambda_string_to_function(function_string)
+
+ def __call__(self, *args, **kwargs):
+ return self._fn(*args, **kwargs)
+
+
+def eval_lambda(function_string, *args):
+ lambda_fn = lambda_string_to_function(function_string)
+ return lambda_fn(*args)
+
+
+FunctionConfig = builds(ConfigDefinedLambda, populate_full_signature=True)
+
+# Inherit from all so it can be used in place of any module.
+CombinedConfig = builds(
+ Combined,
+ populate_full_signature=True,
+ builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig),
+)
+RecurrentConfig = builds(
+ Recurrent,
+ populate_full_signature=True,
+ builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig),
+)
+CreateSlotMaskConfig = builds(CreateSlotMask, populate_full_signature=True)
+
+
+EMASelfDistillationConfig = builds(
+ EMASelfDistillation,
+ populate_full_signature=True,
+ builds_bases=(FeatureExtractorConfig, PerceptualGroupingConfig),
+)
+
+
+def make_slice(expr):
+ if isinstance(expr, int):
+ return expr
+
+ # Convert each segment to an int, keeping empty segments (e.g. in "1:" or ":3") as None.
+ pieces = [int(s) if s else None for s in expr.split(":")]
+ if len(pieces) == 1:
+ return slice(pieces[0], pieces[0] + 1)
+ else:
+ return slice(*pieces)
+
+
+def slice_string(string: str, split_char: str, slice_str: str) -> str:
+ """Split a string according to a split_char and slice.
+
+ If the output contains more than one element, join these using the split char again.
+ """
+ sl = make_slice(slice_str)
+ res = string.split(split_char)[sl]
+ if isinstance(res, list):
+ res = split_char.join(res)
+ return res
+
+
+def register_configs(config_store):
+ config_store.store(group="schemas", name="lambda_fn", node=FunctionConfig)
+ config_store.store(group="utils", name="combined", node=CombinedConfig)
+ config_store.store(group="utils", name="selfdistillation", node=EMASelfDistillationConfig)
+ config_store.store(group="utils", name="recurrent", node=RecurrentConfig)
+ config_store.store(group="utils", name="create_slot_mask", node=CreateSlotMaskConfig)
+
+
+def register_resolvers(omegaconf):
+ omegaconf.register_new_resolver("lambda_fn", ConfigDefinedLambda)
+ omegaconf.register_new_resolver("eval_lambda", eval_lambda)
+ omegaconf.register_new_resolver("slice", slice_string)
diff --git a/ocl/consistency.py b/ocl/consistency.py
new file mode 100644
index 0000000..384b05d
--- /dev/null
+++ b/ocl/consistency.py
@@ -0,0 +1,49 @@
+"""Modules to compute the IoU matching cost and solve the corresponding LSAP."""
+import numpy as np
+import torch
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+
+
+class HungarianMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network."""
+
+ @torch.no_grad()
+ def forward(self, mask_preds, mask_targets):
+ """Performs the matching.
+
+ Params:
+ mask_preds: Tensor of dim [batch_size, n_objects, N, N] with the predicted masks
+ mask_targets: Tensor of dim [batch_size, n_objects, N, N]
+ with the target masks from another augmentation
+
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions
+ - index_j is the indices of the corresponding selected targets
+ """
+ bs, n_objects, _, _ = mask_preds.shape
+ # Compute the IoU cost between masks
+ cost_iou = -get_iou_matrix(mask_preds, mask_targets)
+ cost_iou = cost_iou.reshape(bs, n_objects, bs, n_objects).cpu()
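+ # Keep only the per-sample (n_objects x n_objects) cost blocks, i.e. the diagonal of the
+ # batch-vs-batch cost tensor, and solve one linear sum assignment problem per sample.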
+ self.costs = torch.stack([cost_iou[i, :, i, :][None] for i in range(bs)])
+ indices = [linear_sum_assignment(c[0]) for c in self.costs]
+ return torch.as_tensor(np.array(indices))
+
+
+def get_iou_matrix(preds, targets):
+
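+ # preds and targets are binary masks of shape (batch_size, n_objects, H, W). Both are flattened
+ # to (batch_size * n_objects, H * W) so that all pairwise intersections can be computed with a
+ # single matrix product; unions of zero area are mapped to an IoU of 0.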
+ bs, n_objects, H, W = targets.shape
+ targets = targets.reshape(bs * n_objects, H * W).float()
+ preds = preds.reshape(bs * n_objects, H * W).float()
+
+ intersection = torch.matmul(targets, preds.t())
+ targets_area = targets.sum(dim=1).view(1, -1)
+ preds_area = preds.sum(dim=1).view(1, -1)
+ union = (targets_area.t() + preds_area) - intersection
+
+ return torch.where(
+ union == 0,
+ torch.tensor(0.0, device=targets.device),
+ intersection / union,
+ )
diff --git a/ocl/datasets.py b/ocl/datasets.py
new file mode 100644
index 0000000..25ad33a
--- /dev/null
+++ b/ocl/datasets.py
@@ -0,0 +1,454 @@
+"""Implementation of datasets."""
+import collections
+import logging
+import math
+import os
+from functools import partial
+from itertools import chain
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import webdataset
+from torch.utils.data._utils import collate as torch_collate
+
+from ocl import base
+from ocl.hooks import FakeHooks, hook_implementation
+
+LOGGER = logging.getLogger(__name__)
+
+
+def filter_keys(d: dict, keys_to_keep: Tuple[str, ...] = ()):
+ """Filter dict for keys in keys_to_keep."""
+ # print("keys_to_keep")
+ # print(keys_to_keep)
+ keys_to_keep = ("_",) + keys_to_keep
+ return {
+ key: value
+ for key, value in d.items()
+ if any(key.startswith(prefix) for prefix in keys_to_keep)
+ }
+
+
+def combine_keys(list_of_key_tuples: Sequence[Tuple[str]]):
+ return tuple(set(chain.from_iterable(list_of_key_tuples)))
+
+
+class WebdatasetDataModule(pl.LightningDataModule):
+ """Imagenet Data Module."""
+
+ def __init__(
+ self,
+ train_shards: Optional[Union[str, List[str]]] = None,
+ val_shards: Optional[Union[str, List[str]]] = None,
+ test_shards: Optional[Union[str, List[str]]] = None,
+ batch_size: int = 32,
+ eval_batch_size: Optional[int] = None,
+ num_workers: int = 2,
+ train_size: Optional[int] = None,
+ val_size: Optional[int] = None,
+ test_size: Optional[int] = None,
+ shuffle_train: bool = True,
+ shuffle_buffer_size: int = 3000,
+ use_autopadding: bool = False,
+ continue_on_shard_error: bool = False,
+ hooks: Optional[base.PluggyHookRelay] = None,
+ ):
+ super().__init__()
+ if train_shards is None and val_shards is None and test_shards is None:
+ raise ValueError("No split was specified. Need to specify at least one split.")
+ self.train_shards = train_shards
+ self.val_shards = val_shards
+ self.test_shards = test_shards
+ self.train_size = train_size
+ self.val_size = val_size
+ self.test_size = test_size
+ self.batch_size = batch_size
+ self.eval_batch_size = eval_batch_size if eval_batch_size is not None else batch_size
+ self.num_workers = num_workers
+ self.shuffle_train = shuffle_train
+ self.shuffle_buffer_size = shuffle_buffer_size
+ self.continue_on_shard_error = continue_on_shard_error
+ self.hooks = hooks if hooks else FakeHooks()
+
+ if use_autopadding:
+ self.collate_fn = collate_with_autopadding
+ else:
+ self.collate_fn = collate_with_batch_size
+
+ @staticmethod
+ def _remove_extensions(input_dict):
+ def _remove_extension(name: str):
+ if name.endswith(".gz"):
+ # Webdataset automatically decompresses these, we want to remove two layers of
+ # extensions due to that.
+ name = os.path.splitext(name)[0]
+ return os.path.splitext(name)[0]
+
+ return {_remove_extension(name): value for name, value in input_dict.items()}
+
+ @hook_implementation
+ def on_train_epoch_start(self, model) -> None:
+ """Set environment variables required for better shuffling."""
+ # Required for shuffling of instances across workers, see `epoch_shuffle` parameter of
+ # `webdataset.PytorchShardList`.
+ os.environ["WDS_EPOCH"] = str(model.current_epoch)
+
+ def _create_webdataset(
+ self,
+ uri_expression: Union[str, List[str]],
+ shuffle=False,
+ n_datapoints: Optional[int] = None,
+ keys_to_keep: Sequence[str] = tuple(),
+ transforms: Sequence[Callable[[webdataset.Processor], webdataset.Processor]] = tuple(),
+ ):
+ shard_list = webdataset.PytorchShardList(
+ uri_expression, shuffle=shuffle, epoch_shuffle=shuffle
+ )
+
+ if self.continue_on_shard_error:
+ handler = webdataset.warn_and_continue
+ else:
+ handler = webdataset.reraise_exception
+ dataset = webdataset.WebDataset(shard_list, handler=handler)
+ # Discard unneeded properties of the elements prior to shuffling and decoding.
+ dataset = dataset.map(partial(filter_keys, keys_to_keep=keys_to_keep))
+
+ if shuffle:
+ dataset = dataset.shuffle(self.shuffle_buffer_size)
+
+ # Decode files and remove extensions from input as we already decoded the elements. This
+ # makes our pipeline invariant to the exact encoding used in the dataset.
+ dataset = dataset.decode("rgb8").map(WebdatasetDataModule._remove_extensions)
+
+ # Apply transforms
+ for transform in transforms:
+ dataset = transform(dataset)
+ return dataset.with_length(n_datapoints)
+
+ def _create_dataloader(self, dataset, batch_transforms, size, batch_size, partial_batches):
+ # Don't return partial batches during training as these give the partial samples a higher
+ # weight in the optimization than the other samples of the dataset.
+
+ # Apply batch transforms.
+ dataset = dataset.batched(
+ batch_size,
+ partial=partial_batches,
+ collation_fn=self.collate_fn,
+ )
+ for transform in batch_transforms:
+ dataset = transform(dataset)
+ dataloader = webdataset.WebLoader(
+ dataset,
+ num_workers=self.num_workers,
+ batch_size=None,
+ )
+
+ if size:
+ # This is required for ddp training as we otherwise cannot guarantee that each worker
+ # gets the same number of batches.
+ equalized_size: int
+ if partial_batches:
+ # Round up in the case of partial batches.
+ equalized_size = int(math.ceil(size / batch_size))
+ else:
+ equalized_size = size // batch_size
+
+ dataloader = dataloader.ddp_equalize(equalized_size, with_length=True)
+ else:
+ LOGGER.warning(
+ "Size not provided in the construction of webdataset. "
+ "This may lead to problems when running distributed training."
+ )
+ return dataloader
+
+ def train_data_iterator(self):
+ if self.train_shards is None:
+ raise ValueError("Can not create train_data_iterator. No training split was specified.")
+ return self._create_webdataset(
+ self.train_shards,
+ shuffle=self.shuffle_train,
+ n_datapoints=self.train_size,
+ keys_to_keep=combine_keys(self.hooks.training_fields()),
+ transforms=self.hooks.training_transform(),
+ )
+
+ def train_dataloader(self):
+ return self._create_dataloader(
+ self.train_data_iterator(),
+ self.hooks.training_batch_transform(),
+ self.train_size,
+ self.batch_size,
+ partial_batches=False,
+ )
+
+ def val_data_iterator(self):
+ if self.val_shards is None:
+ raise ValueError("Can not create val_data_iterator. No val split was specified.")
+ return self._create_webdataset(
+ self.val_shards,
+ shuffle=False,
+ n_datapoints=self.val_size,
+ keys_to_keep=combine_keys(self.hooks.evaluation_fields()),
+ transforms=self.hooks.evaluation_transform(),
+ )
+
+ def val_dataloader(self):
+ return self._create_dataloader(
+ self.val_data_iterator(),
+ self.hooks.evaluation_batch_transform(),
+ self.val_size,
+ self.eval_batch_size,
+ partial_batches=True,
+ )
+
+ def test_data_iterator(self):
+ if self.test_shards is None:
+ raise ValueError("Can not create test_data_iterator. No test split was specified.")
+ return self._create_webdataset(
+ self.test_shards,
+ shuffle=False,
+ n_datapoints=self.test_size,
+ keys_to_keep=combine_keys(self.hooks.evaluation_fields()),
+ transforms=self.hooks.evaluation_transform(),
+ )
+
+ def test_dataloader(self):
+ return self._create_dataloader(
+ self.test_data_iterator(),
+ self.hooks.evaluation_batch_transform(),
+ self.test_size,
+ self.eval_batch_size,
+ partial_batches=True,
+ )
+
+
+class DummyDataModule(pl.LightningDataModule):
+ """Dataset providing dummy data for testing."""
+
+ def __init__(
+ self,
+ data_shapes: Dict[str, List[int]],
+ data_types: Dict[str, str],
+ hooks: Optional[base.PluggyHookRelay] = None,
+ batch_size: int = 4,
+ eval_batch_size: Optional[int] = None,
+ train_size: Optional[int] = None,
+ val_size: Optional[int] = None,
+ test_size: Optional[int] = None,
+ # Remaining args needed for compatibility with other datamodules
+ train_shards: Optional[str] = None,
+ val_shards: Optional[str] = None,
+ test_shards: Optional[str] = None,
+ num_workers: Optional[int] = None,
+ ):
+ super().__init__()
+ self.data_shapes = {key: list(shape) for key, shape in data_shapes.items()}
+ self.data_types = data_types
+ self.hooks = hooks if hooks else FakeHooks
+ self.batch_size = batch_size
+ self.eval_batch_size = eval_batch_size if eval_batch_size is not None else batch_size
+
+ self.train_size = train_size
+ if self.train_size is None:
+ self.train_size = 3 * batch_size + 1
+ self.val_size = val_size
+ if self.val_size is None:
+ self.val_size = 2 * batch_size
+ self.test_size = test_size
+ if self.test_size is None:
+ self.test_size = 2 * batch_size
+
+ @staticmethod
+ def _get_random_data_for_dtype(dtype: str, shape: List[int]):
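+ """Generate random data for the given dtype string.
+
+ Supported dtypes: "image", "binary", "uniform", "categorical_<upper>" or
+ "categorical_<lower>_<upper>", and dtype strings starting with "mask"
+ (one-hot masks with the slot/category dimension at position 1).
+ """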
+ if dtype == "image":
+ return np.random.randint(0, 256, size=shape, dtype=np.uint8)
+ elif dtype == "binary":
+ return np.random.randint(0, 2, size=shape, dtype=bool)
+ elif dtype == "uniform":
+ return np.random.rand(*shape).astype(np.float32)
+ elif dtype.startswith("categorical_"):
+ bounds = [int(b) for b in dtype.split("_")[1:]]
+ if len(bounds) == 1:
+ lower, upper = 0, bounds[0]
+ else:
+ lower, upper = bounds
+ np_dtype = np.uint8 if upper <= 256 else np.uint64
+ return np.random.randint(lower, upper, size=shape, dtype=np_dtype)
+ elif dtype.startswith("mask"):
+ categories = shape[1]
+ np_dtype = np.uint8 if categories <= 256 else np.uint64
+ slot_per_pixel = np.random.randint(
+ 0, categories, size=shape[:1] + shape[2:], dtype=np_dtype
+ )
+ return (
+ np.eye(categories)[slot_per_pixel.reshape(-1)]
+ .reshape(shape[:1] + shape[2:] + [categories])
+ .transpose([0, 4, 1, 2, 3])
+ )
+ else:
+ raise ValueError(f"Unsupported dtype `{dtype}`")
+
+ def _create_dataset(
+ self,
+ n_datapoints: int,
+ transforms: Sequence[Callable[[Any], Any]],
+ ):
+ class NumpyDataset(torch.utils.data.IterableDataset):
+ def __init__(self, data: Dict[str, np.ndarray], size: int):
+ super().__init__()
+ self.data = data
+ self.size = size
+
+ def __iter__(self):
+ for i in range(self.size):
+ elem = {key: value[i] for key, value in self.data.items()}
+ elem["__key__"] = str(i)
+ yield elem
+
+ data = {}
+ for key, shape in self.data_shapes.items():
+ data[key] = self._get_random_data_for_dtype(self.data_types[key], [n_datapoints] + shape)
+
+ dataset = webdataset.Processor(NumpyDataset(data, n_datapoints), lambda x: x)
+ for transform in transforms:
+ dataset = transform(dataset)
+
+ return dataset
+
+ def _create_dataloader(self, dataset, batch_size):
+ return torch.utils.data.DataLoader(
+ dataset, batch_size=batch_size, num_workers=0, collate_fn=collate_with_autopadding
+ )
+
+ def train_dataloader(self):
+ dataset = self._create_dataset(self.train_size, self.hooks.training_transform())
+ return self._create_dataloader(dataset, self.batch_size)
+
+ def val_dataloader(self):
+ dataset = self._create_dataset(self.val_size, self.hooks.evaluation_transform())
+ return self._create_dataloader(dataset, self.eval_batch_size)
+
+ def test_dataloader(self):
+ dataset = self._create_dataset(self.test_size, self.hooks.evaluation_transform())
+ return self._create_dataloader(dataset, self.eval_batch_size)
+
+
+def collate_with_batch_size(batch):
+ """Call default pytorch collate function yet for dict type input additionally add batch size."""
+ if isinstance(batch[0], collections.abc.Mapping):
+ out = torch_collate.default_collate(batch)
+ out["batch_size"] = len(batch)
+ return out
+ return torch_collate.default_collate(batch)
+
+
+def collate_with_autopadding(batch):
+ """Collate function that takes a batch of data and stacks it with a batch dimension.
+
+ In contrast to torch's collate function, this function automatically pads tensors of different
+ sizes with zeros such that they can be stacked.
+
+ Adapted from https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py.
+ """
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ # As most tensors will not need padding to be stacked, we first try to stack them normally
+ # and pad only if normal padding fails. This avoids explicitly checking whether all tensors
+ # have the same shape before stacking.
+ try:
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum(x.numel() for x in batch)
+ if len(batch) * elem.numel() != numel:
+ # Check whether resizing will fail because tensors have unequal sizes to avoid
+ # a memory allocation. This is a sufficient but not necessary condition, so it
+ # can happen that this check will not trigger when padding is necessary.
+ raise RuntimeError()
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage).resize_(len(batch), *elem.shape)
+ return torch.stack(batch, 0, out=out)
+ except RuntimeError:
+ # Stacking did not work. Try to pad tensors to the same dimensionality.
+ if not all(x.ndim == elem.ndim for x in batch):
+ raise ValueError("Tensors in batch have different number of dimensions.")
+
+ shapes = [x.shape for x in batch]
+ max_dims = [max(shape[idx] for shape in shapes) for idx in range(elem.ndim)]
+
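+ # Compute zero-padding for each tensor so every tensor reaches the per-dimension maximum size.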
+ paddings = []
+ for shape in shapes:
+ padding = []
+ # torch.nn.functional.pad wants padding from last to first dim, so go in reverse
+ for idx in reversed(range(len(shape))):
+ padding.append(0)
+ padding.append(max_dims[idx] - shape[idx])
+ paddings.append(padding)
+
+ batch_padded = [
+ torch.nn.functional.pad(x, pad, mode="constant", value=0.0)
+ for x, pad in zip(batch, paddings)
+ ]
+
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum(x.numel() for x in batch_padded)
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage).resize_(len(batch_padded), *batch_padded[0].shape)
+ return torch.stack(batch_padded, 0, out=out)
+ elif (
+ elem_type.__module__ == "numpy"
+ and elem_type.__name__ != "str_"
+ and elem_type.__name__ != "string_"
+ ):
+ if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
+ # array of string classes and object
+ if torch_collate.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(torch_collate.default_collate_err_msg_format.format(elem.dtype))
+
+ return collate_with_autopadding([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, str):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ out = {key: collate_with_autopadding([d[key] for d in batch]) for key in elem}
+ out["batch_size"] = len(batch)
+ try:
+ return elem_type(out)
+ except TypeError:
+ # The mapping type may not support `__init__(iterable)`.
+ return out
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
+ return elem_type(*(collate_with_autopadding(samples) for samples in zip(*batch)))
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError("each element in list of batch should be of equal size")
+ transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
+
+ if isinstance(elem, tuple):
+ return [
+ collate_with_autopadding(samples) for samples in transposed
+ ] # Backwards compatibility.
+ else:
+ try:
+ return elem_type([collate_with_autopadding(samples) for samples in transposed])
+ except TypeError:
+ # The sequence type may not support `__init__(iterable)` (e.g., `range`).
+ return [collate_with_autopadding(samples) for samples in transposed]
+
+ raise TypeError(torch_collate.default_collate_err_msg_format.format(elem_type))
diff --git a/ocl/decoding.py b/ocl/decoding.py
new file mode 100644
index 0000000..e702a5e
--- /dev/null
+++ b/ocl/decoding.py
@@ -0,0 +1,1249 @@
+"""Implementation of different types of decoders."""
+import dataclasses
+import math
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchtyping import TensorType
+
+from ocl.base import Instances
+from ocl.neural_networks.convenience import get_activation_fn
+from ocl.neural_networks.positional_embedding import SoftPositionEmbed
+from ocl.neural_networks.slate import Conv2dBlockWithGroupNorm
+from ocl.path_defaults import OBJECTS
+from ocl.utils.bboxes import box_cxcywh_to_xyxy
+from ocl.utils.resizing import resize_patches_to_image
+from ocl.utils.routing import RoutableMixin
+
+
+@dataclasses.dataclass
+class SimpleReconstructionOutput:
+ reconstruction: TensorType["batch_size", "channels", "height", "width"] # noqa: F821
+
+
+@dataclasses.dataclass
+class ReconstructionOutput:
+ reconstruction: TensorType["batch_size", "channels", "height", "width"] # noqa: F821
+ object_reconstructions: TensorType[
+ "batch_size", "n_objects", "channels", "height", "width" # noqa: F821
+ ]
+ masks: TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+
+
+@dataclasses.dataclass
+class ReconstructionAmodalOutput:
+ reconstruction: TensorType["batch_size", "channels", "height", "width"] # noqa: F821
+ object_reconstructions: TensorType[
+ "batch_size", "n_objects", "channels", "height", "width" # noqa: F821
+ ]
+ masks: TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+ masks_vis: TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+ masks_eval: TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+
+
+@dataclasses.dataclass
+class PatchReconstructionOutput:
+ reconstruction: TensorType["batch_size", "n_patches", "n_patch_features"] # noqa: F821
+ masks: TensorType["batch_size", "n_objects", "n_patches"] # noqa: F821
+ masks_as_image: Optional[
+ TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+ ] = None
+ target: Optional[TensorType["batch_size", "n_patches", "n_patch_features"]] = None # noqa: F821
+
+
+@dataclasses.dataclass
+class DepthReconstructionOutput(ReconstructionOutput):
+ masks_amodal: Optional[
+ TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+ ] = None
+ depth_map: Optional[TensorType["batch_size", "height", "width"]] = None # noqa: F821
+ object_depth_map: Optional[
+ TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+ ] = None
+ densities: Optional[
+ TensorType["batch_size", "n_objects", "n_depth", "height", "width"] # noqa: F821
+ ] = None
+ colors: Optional[
+ TensorType["batch_size", "n_objects", "n_depth", "channels", "height", "width"] # noqa: F821
+ ] = None
+
+
+@dataclasses.dataclass
+class OpticalFlowPredictionTaskOutput:
+ predicted_flow: TensorType["batch_size", "channels", "height", "width"] # noqa: F821
+ object_flows: TensorType["batch_size", "n_objects", "channels", "height", "width"] # noqa: F821
+ masks: TensorType["batch_size", "n_objects", "height", "width"] # noqa: F821
+
+
+@dataclasses.dataclass
+class BBoxOutput:
+ bboxes: TensorType["batch_size", "n_objects", "box_dim"] # noqa: F821
+ classes: TensorType["batch_size", "n_objects", "num_classes"] # noqa: F821
+ ori_res_bboxes: TensorType["batch_size", "n_objects", "box_dim"] # noqa: F821
+ inference_obj_idxes: TensorType["batch_size", "n_objects"] # noqa: F821
+
+
+def build_grid_of_positions(resolution):
+ """Build grid of positions which can be used to create positions embeddings."""
+ ranges = [torch.linspace(0.0, 1.0, steps=res) for res in resolution]
+ grid = torch.meshgrid(*ranges, indexing="ij")
+ grid = torch.stack(grid, dim=-1)
+ grid = torch.reshape(grid, [resolution[0], resolution[1], -1])
+ return grid
+
+
+def get_slotattention_decoder_backbone(object_dim: int, output_dim: int = 4):
+ """Get CNN decoder backbone form the original slot attention paper."""
+ return nn.Sequential(
+ nn.ConvTranspose2d(object_dim, 64, 5, stride=2, padding=2, output_padding=1),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(64, 64, 5, stride=2, padding=2, output_padding=1),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(64, 64, 5, stride=2, padding=2, output_padding=1),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(64, 64, 5, stride=2, padding=2, output_padding=1),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(64, 64, 5, stride=1, padding=2, output_padding=0),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(64, output_dim, 3, stride=1, padding=1, output_padding=0),
+ )
+
+
+def get_savi_decoder_backbone(
+ object_dim: int,
+ output_dim: int = 4,
+ larger_input_arch: bool = False,
+ channel_multiplier: float = 1,
+):
+ """Get CNN decoder backbone form the slot attention for video paper."""
+ channels = int(64 * channel_multiplier)
+ if larger_input_arch:
+ output_stride = 2
+ output_padding = 1
+ else:
+ output_stride = 1
+ output_padding = 0
+ return nn.Sequential(
+ nn.ConvTranspose2d(object_dim, channels, 5, stride=2, padding=2, output_padding=1),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(channels, channels, 5, stride=2, padding=2, output_padding=1),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(channels, channels, 5, stride=2, padding=2, output_padding=1),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(
+ channels, channels, 5, stride=output_stride, padding=2, output_padding=output_padding
+ ),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(
+ channels,
+ output_dim,
+ 1,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ ),
+ )
+
+
+def get_dvae_decoder(vocab_size: int, output_dim: int = 3):
+ """Get CNN decoder backbone for DVAE module in SLATE paper."""
+ conv2d = nn.Conv2d(64, output_dim, 1)
+ nn.init.xavier_uniform_(conv2d.weight)
+ nn.init.zeros_(conv2d.bias)
+ return nn.Sequential(
+ Conv2dBlockWithGroupNorm(vocab_size, 64, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 3, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64 * 2 * 2, 1),
+ nn.PixelShuffle(2),
+ Conv2dBlockWithGroupNorm(64, 64, 3, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64 * 2 * 2, 1),
+ nn.PixelShuffle(2),
+ conv2d,
+ )
+
+
+def get_dvae_encoder(vocab_size: int, patch_size: int = 16, output_dim: int = 3):
+ """Get CNN decoder backbone for DVAE module in SLATE paper."""
+ conv2d = nn.Conv2d(64, vocab_size, 1)
+ nn.init.xavier_uniform_(conv2d.weight)
+ nn.init.zeros_(conv2d.bias)
+
+ return nn.Sequential(
+ Conv2dBlockWithGroupNorm(output_dim, 64, patch_size, patch_size),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ Conv2dBlockWithGroupNorm(64, 64, 1, 1),
+ conv2d,
+ )
+
+
+class StyleGANv2Decoder(nn.Module):
+ """CNN decoder as used in StyleGANv2 and GIRAFFE."""
+
+ def __init__(
+ self,
+ object_feature_dim: int,
+ output_dim: int = 4,
+ min_features=32,
+ input_size: int = 8,
+ output_size: int = 128,
+ activation_fn: str = "leaky_relu",
+ leaky_relu_slope: float = 0.2,
+ ):
+ super().__init__()
+ input_size_log2 = math.log2(input_size)
+ assert math.floor(input_size_log2) == input_size_log2, "Input size needs to be power of 2"
+
+ output_size_log2 = math.log2(output_size)
+ assert math.floor(output_size_log2) == output_size_log2, "Output size needs to be power of 2"
+
+ n_blocks = int(output_size_log2 - input_size_log2)
+
+ self.convs = nn.ModuleList()
+ cur_dim = object_feature_dim
+ for _ in range(n_blocks):
+ next_dim = max(cur_dim // 2, min_features)
+ self.convs.append(nn.Conv2d(cur_dim, next_dim, 3, stride=1, padding=1))
+ cur_dim = next_dim
+
+ self.skip_convs = nn.ModuleList()
+ cur_dim = object_feature_dim
+ for _ in range(n_blocks + 1):
+ self.skip_convs.append(nn.Conv2d(cur_dim, output_dim, 1, stride=1))
+ cur_dim = max(cur_dim // 2, min_features)
+
+ nn.init.zeros_(self.skip_convs[-1].bias)
+
+ if activation_fn == "relu":
+ self.activation_fn = nn.ReLU(inplace=True)
+ elif activation_fn == "leaky_relu":
+ self.activation_fn = nn.LeakyReLU(leaky_relu_slope, inplace=True)
+ else:
+ raise ValueError(f"Unknown activation function {activation_fn}")
+
+ def forward(self, inp):
+ output = self.skip_convs[0](inp)
+
+ features = inp
+ for conv, skip_conv in zip(self.convs, self.skip_convs[1:]):
+ features = F.interpolate(features, scale_factor=2, mode="nearest-exact")
+ features = conv(features)
+ features = self.activation_fn(features)
+
+ output = F.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=False, antialias=True
+ )
+ output = output + skip_conv(features)
+
+ return output
+
+
+class SlotAttentionDecoder(nn.Module, RoutableMixin):
+ """Decoder used in the original slot attention paper."""
+
+ def __init__(
+ self,
+ decoder: nn.Module,
+ final_activation: Union[str, Callable] = "identity",
+ positional_embedding: Optional[nn.Module] = None,
+ object_features_path: Optional[str] = OBJECTS,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"object_features": object_features_path})
+ self.initial_conv_size = (8, 8)
+ self.decoder = decoder
+ self.final_activation = get_activation_fn(final_activation)
+ self.positional_embedding = positional_embedding
+ if positional_embedding:
+ self.register_buffer("grid", build_grid_of_positions(self.initial_conv_size))
+
+ @RoutableMixin.route
+ def forward(self, object_features: torch.Tensor):
+ assert object_features.dim() >= 3 # Image or video data.
+ initial_shape = object_features.shape[:-1]
+ object_features = object_features.flatten(0, -2)
+
+ object_features = (
+ object_features.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, *self.initial_conv_size)
+ )
+ if self.positional_embedding:
+ object_features = self.positional_embedding(object_features, self.grid.unsqueeze(0))
+
+ # Apply deconvolution and restore object dimension.
+ output = self.decoder(object_features)
+ output = output.unflatten(0, initial_shape)
+
+ # Split out alpha channel and normalize over slots.
+ # The decoder is assumed to output tensors in CNN order, i.e. * x C x H x W.
+ rgb, alpha = output.split([3, 1], dim=-3)
+ rgb = self.final_activation(rgb)
+ alpha = alpha.softmax(dim=-4)
+
+ return ReconstructionOutput(
+ # Combine rgb weighted according to alpha channel.
+ reconstruction=(rgb * alpha).sum(-4),
+ object_reconstructions=rgb,
+ masks=alpha.squeeze(-3),
+ )
+
+
+class SlotAttentionDecoderGumbel(nn.Module, RoutableMixin):
+ """Decoder used in the original slot attention paper."""
+
+ def __init__(
+ self,
+ decoder: nn.Module,
+ final_activation: Union[str, Callable] = "identity",
+ positional_embedding: Optional[nn.Module] = None,
+ object_features_path: Optional[str] = OBJECTS,
+ left_mask_path: Optional[str] = None,
+ mask_type = "mask_normalized"
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"object_features": object_features_path, "left_mask":left_mask_path})
+ self.initial_conv_size = (8, 8)
+ self.decoder = decoder
+ self.final_activation = get_activation_fn(final_activation)
+ self.positional_embedding = positional_embedding
+ if positional_embedding:
+ self.register_buffer("grid", build_grid_of_positions(self.initial_conv_size))
+
+ self.mask_type = mask_type
+
+ @RoutableMixin.route
+ def forward(self, object_features: torch.Tensor, left_mask: Optional[torch.Tensor] = None):
+ assert object_features.dim() >= 3 # Image or video data.
+ initial_shape = object_features.shape[:-1]
+ object_features = object_features.flatten(0, -2)
+
+ object_features = (
+ object_features.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, *self.initial_conv_size)
+ )
+ if self.positional_embedding:
+ object_features = self.positional_embedding(object_features, self.grid.unsqueeze(0))
+
+ # Apply deconvolution and restore object dimension.
+ output = self.decoder(object_features)
+ output = output.unflatten(0, initial_shape)
+
+ # Split out alpha channel and normalize over slots.
+ # The decoder is assumed to output tensors in CNN order, i.e. * x C x H x W.
+ rgb, alpha = output.split([3, 1], dim=-3)
+ rgb = self.final_activation(rgb)
+ # rgb and alpha have shape B x K x C x H x W (C=3 for rgb, C=1 for alpha).
+
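+ # Suppress unselected slots according to `left_mask` (1 = keep slot, 0 = drop slot):
+ # "logit" masks the alpha logits before the softmax, "mask" zeroes the softmaxed alphas,
+ # "mask_normalized" additionally renormalizes the remaining alphas, and "none" keeps all slots.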
+ if self.mask_type == "logit":
+ VANISH = 1e5
+ drop_mask = 1 - left_mask # (b, s)
+ alpha = alpha - VANISH * drop_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) # (b, s, c, h, w)
+ alpha = alpha.softmax(dim=-4)
+ elif self.mask_type == "mask":
+ # drop_mask = 1 - left_mask # (b, s)
+ alpha = alpha.softmax(dim=-4) # (b, s, c, h, w)
+ alpha = alpha * left_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ elif self.mask_type == "mask_normalized":
+ # drop_mask = 1 - left_mask # (b, s)
+ MINOR = 1e-5
+ alpha = alpha.softmax(dim=-4) # (b, s, c, h, w)
+ alpha = alpha * left_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) # (b, s, c, h, w)
+ alpha = alpha/(alpha.sum(dim=-4, keepdim=True) + MINOR)
+ elif self.mask_type == "none":
+ alpha = alpha.softmax(dim=-4) # (b, s, c, h, w)
+
+ return ReconstructionOutput(
+ # Combine rgb weighted according to alpha channel.
+ reconstruction=(rgb * alpha).sum(-4),
+ object_reconstructions=rgb,
+ masks=alpha.squeeze(-3),
+ )
+
+
+class SlotAttentionAmodalDecoder(nn.Module, RoutableMixin):
+ """Slot attention decoder variant that additionally predicts amodal object masks."""
+
+ def __init__(
+ self,
+ decoder: nn.Module,
+ final_activation: Union[str, Callable] = "identity",
+ positional_embedding: Optional[nn.Module] = None,
+ object_features_path: Optional[str] = OBJECTS,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"object_features": object_features_path})
+ self.initial_conv_size = (8, 8)
+ self.decoder = decoder
+ self.final_activation = get_activation_fn(final_activation)
+ self.positional_embedding = positional_embedding
+ if positional_embedding:
+ self.register_buffer("grid", build_grid_of_positions(self.initial_conv_size))
+
+ def rescale_mask(self, mask):
+ mask_min, mask_max = torch.min(mask), torch.max(mask)
+ return (mask - mask_min) / (mask_max - mask_min)
+
+ @RoutableMixin.route
+ def forward(self, object_features: torch.Tensor):
+ assert object_features.dim() >= 3 # Image or video data.
+ initial_shape = object_features.shape[:-1]
+
+ object_features_ori = object_features.clone()
+
+ object_features = object_features.flatten(0, -2)
+ object_features = (
+ object_features.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, *self.initial_conv_size)
+ )
+ if self.positional_embedding:
+ object_features = self.positional_embedding(object_features, self.grid.unsqueeze(0))
+
+ # Apply deconvolution and restore object dimension.
+ output = self.decoder(object_features)
+ output = output.unflatten(0, initial_shape)
+
+ # Split out alpha channel and normalize over slots.
+ # The decoder is assumed to output tensors in CNN order, i.e. * x C x H x W.
+ rgb, alpha = output.split([3, 1], dim=-3)
+ rgb = self.final_activation(rgb)
+ alpha1 = alpha.softmax(dim=-4) # visible masks
+ alpha2 = alpha.sigmoid() # amodal masks
+
+ masks_vis = torch.zeros(alpha1.shape).to(alpha1.device)
+ for b in range(object_features_ori.shape[0]):
+ index = torch.sum(object_features_ori[b], dim=-1).nonzero(as_tuple=True)[0]
+ masks_vis[b][index] = alpha1[b][index]
+ for i in index:
+ masks_vis[b][i] = self.rescale_mask(alpha1[b][i])
+
+ return ReconstructionAmodalOutput(
+ # Combine rgb weighted according to alpha channel.
+ reconstruction=(rgb * alpha1).sum(-4),
+ object_reconstructions=rgb,
+ masks=alpha2.squeeze(-3),
+ masks_vis=alpha1.squeeze(-3),
+ masks_eval=masks_vis.squeeze(-3),
+ )
+
+
+class SlotAttentionOpticalFlowDecoder(nn.Module, RoutableMixin):
+ # TODO(flwenzel): for now use the same decoder as for rgb reconstruction. Might implement
+ # improved/specialized decoder.
+ # TODO(hornmax): Maybe we can merge this with the RGB decoder and generalize the task outputs.
+ # This is something for a later time though.
+
+ def __init__(
+ self,
+ decoder: nn.Module,
+ positional_embedding: Optional[nn.Module] = None,
+ object_features_path: Optional[str] = OBJECTS,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"object_features": object_features_path})
+ self.initial_conv_size = (8, 8)
+ self.decoder = decoder
+ self.positional_embedding = positional_embedding
+ if positional_embedding:
+ self.register_buffer("grid", build_grid_of_positions(self.initial_conv_size))
+
+ @RoutableMixin.route
+ def forward(self, object_features: torch.Tensor):
+ assert object_features.dim() >= 3 # Image or video data.
+ initial_shape = object_features.shape[:-1]
+ object_features = object_features.flatten(0, -2)
+
+ object_features = (
+ object_features.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, *self.initial_conv_size)
+ )
+ if self.positional_embedding:
+ object_features = self.positional_embedding(object_features, self.grid.unsqueeze(0))
+
+ # Apply deconvolution and restore object dimension.
+ output = self.decoder(object_features)
+ output = output.unflatten(0, initial_shape)
+
+ # Split out alpha channel and normalize over slots.
+ # The decoder is assumed to output tensors in CNN order, i.e. * x C x H x W.
+ flow, alpha = output.split([2, 1], dim=2) # flow is assumed to be 2-dim.
+ alpha = alpha.softmax(dim=-4)
+
+ return OpticalFlowPredictionTaskOutput(
+ # Combine flow weighted according to alpha channel.
+ predicted_flow=(flow * alpha).sum(-4),
+ object_flows=flow,
+ masks=alpha.squeeze(-3),
+ )
+
+
+class PatchDecoder(nn.Module, RoutableMixin):
+ """Decoder that takes object representations and reconstructs patches.
+
+ Args:
+ object_dim: Dimension of objects representations.
+ output_dim: Dimension of each patch.
+ num_patches: Number of patches P to reconstruct.
+ decoder: Function that returns the backbone to use for decoding. The function takes input
+ and output dimensions and should return a module that takes inputs of shape (B * K), P, N
+ and produces outputs of shape (B * K), P, M, where K is the number of objects, N is the
+ number of input dimensions and M the number of output dimensions.
+ decoder_input_dim: Input dimension to decoder backbone. If specified, a linear
+ transformation from object to decoder dimension is added. If not specified, the object
+ dimension is used and no linear transform is added.
+ """
+
+ def __init__(
+ self,
+ object_dim: int,
+ output_dim: int,
+ num_patches: int,
+ decoder: Callable[[int, int], nn.Module],
+ decoder_input_dim: Optional[int] = None,
+ upsample_target: Optional[float] = None,
+ resize_mode: str = "bilinear",
+ object_features_path: Optional[str] = OBJECTS,
+ target_path: Optional[str] = None,
+ image_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {"object_features": object_features_path, "target": target_path, "image": image_path},
+ )
+ self.output_dim = output_dim
+ self.num_patches = num_patches
+ self.upsample_target = upsample_target
+ self.resize_mode = resize_mode
+
+ if decoder_input_dim is not None:
+ self.inp_transform = nn.Linear(object_dim, decoder_input_dim, bias=True)
+ nn.init.xavier_uniform_(self.inp_transform.weight)
+ nn.init.zeros_(self.inp_transform.bias)
+ else:
+ self.inp_transform = None
+ decoder_input_dim = object_dim
+
+ self.decoder = decoder(decoder_input_dim, output_dim + 1)
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_input_dim) * 0.02)
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ object_features: torch.Tensor,
+ target: Optional[torch.Tensor] = None,
+ image: Optional[torch.Tensor] = None,
+ ):
+ assert object_features.dim() >= 3 # Image or video data.
+ if self.upsample_target is not None and target is not None:
+ target = (
+ resize_patches_to_image(
+ target.detach().transpose(-2, -1),
+ scale_factor=self.upsample_target,
+ resize_mode=self.resize_mode,
+ )
+ .flatten(-2, -1)
+ .transpose(-2, -1)
+ )
+
+ initial_shape = object_features.shape[:-1]
+ object_features = object_features.flatten(0, -2)
+
+ if self.inp_transform is not None:
+ object_features = self.inp_transform(object_features)
+
+ object_features = object_features.unsqueeze(1).expand(-1, self.num_patches, -1)
+
+ # Simple learned additive embedding as in ViT
+ object_features = object_features + self.pos_embed
+
+ output = self.decoder(object_features)
+ output = output.unflatten(0, initial_shape)
+
+ # Split out alpha channel and normalize over slots.
+ decoded_patches, alpha = output.split([self.output_dim, 1], dim=-1)
+ alpha = alpha.softmax(dim=-3)
+
+ reconstruction = torch.sum(decoded_patches * alpha, dim=-3)
+ masks = alpha.squeeze(-1)
+
+ if image is not None:
+ masks_as_image = resize_patches_to_image(
+ masks, size=image.shape[-1], resize_mode="bilinear"
+ )
+ else:
+ masks_as_image = None
+
+ return PatchReconstructionOutput(
+ reconstruction=reconstruction,
+ masks=alpha.squeeze(-1),
+ masks_as_image=masks_as_image,
+ target=target,
+ )
+
+
+class PatchDecoderGumbelV1(nn.Module, RoutableMixin):
+ """Patch decoder that reconstructs patches while suppressing unselected slots via `left_mask`.
+
+ Args:
+ object_dim: Dimension of objects representations.
+ output_dim: Dimension of each patch.
+ num_patches: Number of patches P to reconstruct.
+ decoder: Function that returns the backbone to use for decoding. The function takes input
+ and output dimensions and should return a module that takes inputs of shape (B * K), P, N
+ and produces outputs of shape (B * K), P, M, where K is the number of objects, N is the
+ number of input dimensions and M the number of output dimensions.
+ decoder_input_dim: Input dimension to decoder backbone. If specified, a linear
+ transformation from object to decoder dimension is added. If not specified, the object
+ dimension is used and no linear transform is added.
+ """
+
+ def __init__(
+ self,
+ object_dim: int,
+ output_dim: int,
+ num_patches: int,
+ decoder: Callable[[int, int], nn.Module],
+ decoder_input_dim: Optional[int] = None,
+ upsample_target: Optional[float] = None,
+ resize_mode: str = "bilinear",
+ object_features_path: Optional[str] = OBJECTS,
+ target_path: Optional[str] = None,
+ image_path: Optional[str] = None,
+ left_mask_path: Optional[str] = None,
+ mask_type = "mask_normalized"
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {"object_features": object_features_path, "target": target_path, "image": image_path,
+ "left_mask":left_mask_path},
+ )
+ self.output_dim = output_dim
+ self.num_patches = num_patches
+ self.upsample_target = upsample_target
+ self.resize_mode = resize_mode
+
+ if decoder_input_dim is not None:
+ self.inp_transform = nn.Linear(object_dim, decoder_input_dim, bias=True)
+ nn.init.xavier_uniform_(self.inp_transform.weight)
+ nn.init.zeros_(self.inp_transform.bias)
+ else:
+ self.inp_transform = None
+ decoder_input_dim = object_dim
+
+ self.decoder = decoder(decoder_input_dim, output_dim + 1)
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches, decoder_input_dim) * 0.02)
+ self.mask_type = mask_type
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ object_features: torch.Tensor,
+ target: Optional[torch.Tensor] = None,
+ image: Optional[torch.Tensor] = None,
+ left_mask: Optional[torch.Tensor] = None,
+ ):
+ assert object_features.dim() >= 3 # Image or video data.
+ if self.upsample_target is not None and target is not None:
+ target = (
+ resize_patches_to_image(
+ target.detach().transpose(-2, -1),
+ scale_factor=self.upsample_target,
+ resize_mode=self.resize_mode,
+ )
+ .flatten(-2, -1)
+ .transpose(-2, -1)
+ )
+
+ initial_shape = object_features.shape[:-1]
+ object_features = object_features.flatten(0, -2) # (b*s, d)
+
+ if self.inp_transform is not None:
+ object_features = self.inp_transform(object_features)
+
+ object_features = object_features.unsqueeze(1).expand(-1, self.num_patches, -1) # (b*s, n, d)
+
+ # Simple learned additive embedding as in ViT
+ object_features = object_features + self.pos_embed
+
+ output = self.decoder(object_features) # (b*s, n, d + 1)
+ output = output.unflatten(0, initial_shape) # (b, s, n, d + 1)
+
+ # Split out alpha channel and normalize over slots.
+ decoded_patches, alpha = output.split([self.output_dim, 1], dim=-1)
+ # (b, s, n, d), (b, s, n, 1)
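+ # Suppress unselected slots according to `left_mask` (1 = keep slot, 0 = drop slot),
+ # analogous to the masked CNN decoder above.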
+ if self.mask_type == "logit":
+ VANISH = 1e5
+ drop_mask = 1 - left_mask # (b, s)
+ alpha = alpha - VANISH * drop_mask.unsqueeze(-1).unsqueeze(-1) # (b, s, n, 1)
+ alpha = alpha.softmax(dim=-3)
+ elif self.mask_type == "mask":
+ # drop_mask = 1 - left_mask # (b, s)
+ alpha = alpha.softmax(dim=-3) # (b, s, n, 1)
+ alpha = alpha * left_mask.unsqueeze(-1).unsqueeze(-1)
+ elif self.mask_type == "mask_normalized":
+ # drop_mask = 1 - left_mask # (b, s)
+ MINOR = 1e-5
+ alpha = alpha.softmax(dim=-3) # (b, s, n, 1)
+ alpha = alpha * left_mask.unsqueeze(-1).unsqueeze(-1) # (b, s, n, 1)
+ alpha = alpha/(alpha.sum(dim=-3, keepdim=True) + MINOR)
+ elif self.mask_type == "none":
+ alpha = alpha.softmax(dim=-3) # (b, s, n, 1)
+
+ reconstruction = torch.sum(decoded_patches * alpha, dim=-3)
+ masks = alpha.squeeze(-1)
+
+ if image is not None:
+ masks_as_image = resize_patches_to_image(
+ masks, size=image.shape[-1], resize_mode="bilinear"
+ )
+ else:
+ masks_as_image = None
+
+ return PatchReconstructionOutput(
+ reconstruction=reconstruction,
+ masks=alpha.squeeze(-1),
+ masks_as_image=masks_as_image,
+ target=target,
+ )
+
+
+
+class AutoregressivePatchDecoder(nn.Module, RoutableMixin):
+ """Decoder that takes object representations and reconstructs patches autoregressively.
+
+ Args:
+ object_dim: Dimension of objects representations.
+ output_dim: Dimension of each patch.
+ num_patches: Number of patches P to reconstruct.
+ decoder: Function that returns backbone to use for decoding. Function takes input and output
+ dimensions and should return module that takes autoregressive targets of shape B, P, M,
+ conditioning of shape B, K, N, masks of shape P, P, and produces outputs of shape
+ B, P, M, where K is the number of objects, N is the number of input dimensions and M the
+ number of output dimensions.
+ decoder_cond_dim: Dimension of conditioning input of decoder backbone. If specified, a linear
+ transformation from object to decoder dimension is added. If not specified, the object
+ dimension is used and no linear transform is added.
+ """
+
+ def __init__(
+ self,
+ object_dim: int,
+ output_dim: int,
+ num_patches: int,
+ decoder: Callable[[int, int], nn.Module],
+ decoder_dim: Optional[int] = None,
+ decoder_cond_dim: Optional[int] = None,
+ upsample_target: Optional[float] = None,
+ resize_mode: str = "bilinear",
+ use_decoder_masks: bool = False,
+ use_bos_token: bool = True,
+ use_input_transform: bool = False,
+ use_input_norm: bool = False,
+ use_output_transform: bool = False,
+ use_positional_embedding: bool = False,
+ object_features_path: Optional[str] = OBJECTS,
+ masks_path: Optional[str] = "perceptual_grouping.masks",
+ target_path: Optional[str] = None,
+ image_path: Optional[str] = None,
+ empty_objects_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "object_features": object_features_path,
+ "target": target_path,
+ "image": image_path,
+ "masks": masks_path,
+ "empty_objects": empty_objects_path,
+ },
+ )
+ self.output_dim = output_dim
+ self.num_patches = num_patches
+ self.upsample_target = upsample_target
+ self.resize_mode = resize_mode
+ self.use_decoder_masks = use_decoder_masks
+
+ if decoder_dim is None:
+ decoder_dim = output_dim
+
+ self.decoder = decoder(decoder_dim, decoder_dim)
+ if use_bos_token:
+ self.bos_token = nn.Parameter(torch.randn(1, 1, output_dim) * output_dim**-0.5)
+ else:
+ self.bos_token = None
+ if decoder_cond_dim is not None:
+ self.cond_transform = nn.Sequential(
+ nn.Linear(object_dim, decoder_cond_dim, bias=False),
+ nn.LayerNorm(decoder_cond_dim, eps=1e-5),
+ )
+ nn.init.xavier_uniform_(self.cond_transform[0].weight)
+ else:
+ decoder_cond_dim = object_dim
+ self.cond_transform = nn.LayerNorm(decoder_cond_dim, eps=1e-5)
+
+ if use_input_transform:
+ self.inp_transform = nn.Sequential(
+ nn.Linear(output_dim, decoder_dim, bias=False),
+ nn.LayerNorm(decoder_dim, eps=1e-5),
+ )
+ nn.init.xavier_uniform_(self.inp_transform[0].weight)
+ elif use_input_norm:
+ self.inp_transform = nn.LayerNorm(decoder_dim, eps=1e-5)
+ else:
+ self.inp_transform = None
+
+ if use_output_transform:
+ self.outp_transform = nn.Linear(decoder_dim, output_dim)
+ nn.init.xavier_uniform_(self.outp_transform.weight)
+ nn.init.zeros_(self.outp_transform.bias)
+ else:
+ self.outp_transform = None
+
+ if use_positional_embedding:
+ self.pos_embed = nn.Parameter(
+ torch.randn(1, num_patches, decoder_dim) * decoder_dim**-0.5
+ )
+ else:
+ self.pos_embed = None
+
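+ # Causal (upper-triangular) attention mask so each patch only attends to previous patches.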
+ mask = torch.triu(torch.full((num_patches, num_patches), float("-inf")), diagonal=1)
+ self.register_buffer("mask", mask)
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ object_features: torch.Tensor,
+ masks: torch.Tensor,
+ target: torch.Tensor,
+ image: Optional[torch.Tensor] = None,
+ empty_objects: Optional[torch.Tensor] = None,
+ ) -> PatchReconstructionOutput:
+ assert object_features.dim() >= 3 # Image or video data.
+ if self.upsample_target is not None and target is not None:
+ target = (
+ resize_patches_to_image(
+ target.detach().transpose(-2, -1),
+ scale_factor=self.upsample_target,
+ resize_mode=self.resize_mode,
+ )
+ .flatten(-2, -1)
+ .transpose(-2, -1)
+ )
+ # Squeeze frames into batch if present.
+ object_features = object_features.flatten(0, -3)
+
+ object_features = self.cond_transform(object_features)
+
+ # Squeeze frame into batch size if necessary.
+ initial_targets_shape = target.shape[:-2]
+ targets = target.flatten(0, -3)
+ if self.bos_token is not None:
+ bs = len(object_features)
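+ # Teacher forcing: prepend the BOS token and shift the (detached) targets right by one.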
+ inputs = torch.cat((self.bos_token.expand(bs, -1, -1), targets[:, :-1].detach()), dim=1)
+ else:
+ inputs = targets
+
+ if self.inp_transform is not None:
+ inputs = self.inp_transform(inputs)
+
+ if self.pos_embed is not None:
+ # Simple learned additive embedding as in ViT
+ inputs = inputs + self.pos_embed
+
+ if empty_objects is not None:
+ outputs = self.decoder(
+ inputs, object_features, self.mask, memory_key_padding_mask=empty_objects
+ )
+ else:
+ outputs = self.decoder(inputs, object_features, self.mask)
+
+ if self.use_decoder_masks:
+ decoded_patches, masks = outputs
+ else:
+ decoded_patches = outputs
+
+ if self.outp_transform is not None:
+ decoded_patches = self.outp_transform(decoded_patches)
+
+ decoded_patches = decoded_patches.unflatten(0, initial_targets_shape)
+
+ if image is not None:
+ masks_as_image = resize_patches_to_image(
+ masks, size=image.shape[-1], resize_mode="bilinear"
+ )
+ else:
+ masks_as_image = None
+
+ return PatchReconstructionOutput(
+ reconstruction=decoded_patches, masks=masks, masks_as_image=masks_as_image, target=target
+ )
+
+
+class DensityPredictingSlotAttentionDecoder(nn.Module, RoutableMixin):
+ """Decoder predicting color and densities along a ray into the scene."""
+
+ def __init__(
+ self,
+ object_dim: int,
+ decoder: nn.Module,
+ depth_positions: int,
+ white_background: bool = False,
+ normalize_densities_along_slots: bool = False,
+ initial_alpha: Optional[float] = None,
+ object_features_path: Optional[str] = OBJECTS,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"object_features": object_features_path})
+ self.initial_conv_size = (8, 8)
+ self.depth_positions = depth_positions
+ self.white_background = white_background
+ self.normalize_densities_along_slots = normalize_densities_along_slots
+ self.register_buffer("grid", build_grid_of_positions(self.initial_conv_size))
+ self.pos_embedding = SoftPositionEmbed(2, object_dim, cnn_channel_order=True)
+
+ self.decoder = decoder
+ if isinstance(self.decoder, nn.Sequential) and hasattr(self.decoder[-1], "bias"):
+ nn.init.zeros_(self.decoder[-1].bias)
+
+ if initial_alpha is not None:
+ # Distance between neighboring ray points, currently assumed to be 1
+ point_distance = 1
+ # Value added to density output of network before softplus activation. If network outputs
+ # are approximately zero, the initial mask value per voxel becomes `initial_alpha`. See
+ # https://arxiv.org/abs/2111.11215 for a derivation.
+ self.initial_density_offset = math.log((1 - initial_alpha) ** (-1 / point_distance) - 1)
+ else:
+ self.initial_density_offset = 0.0
+
+ def _render_objectwise(self, densities, rgbs):
+ """Render objects individually.
+
+ Args:
+ densities: Predicted densities of shape (B, S, Z, H, W), where S is the number of slots
+ and Z is the number of depth positions.
+ rgbs: Predicted color values of shape (B, S, 3, H, W), where S is the number of slots.
+ """
+ densities_objectwise = densities.flatten(0, 1).unsqueeze(2)
+ rgbs_objectwise = rgbs.flatten(0, 1).unsqueeze(1)
+ rgbs_objectwise = rgbs_objectwise.expand(-1, densities_objectwise.shape[1], -1, -1, -1)
+
+ if self.white_background:
+ background = torch.full_like(rgbs_objectwise[:, 0], 1.0) # White color, i.e. 0xFFFFFF
+ else:
+ background = None
+
+ object_reconstructions, _, object_masks_per_depth, p_ray_hits_points = volume_rendering(
+ densities_objectwise, rgbs_objectwise, background=background
+ )
+
+ object_reconstructions = object_reconstructions.unflatten(0, rgbs.shape[:2])
+ object_masks_per_depth = object_masks_per_depth.squeeze(2).unflatten(0, rgbs.shape[:2])
+ p_ray_hits_points = p_ray_hits_points.squeeze(2).unflatten(0, rgbs.shape[:2])
+
+ p_ray_hits_points_and_reflects = p_ray_hits_points * object_masks_per_depth
+ object_masks, object_depth_map = p_ray_hits_points_and_reflects.max(2)
+
+ return object_reconstructions, object_masks, object_depth_map
+
+ @RoutableMixin.route
+ def forward(self, object_features: torch.Tensor):
+ # TODO(hornmax): Adapt this for video data.
+ # Reshape object dimension into batch dimension and broadcast.
+ bs, n_objects, object_feature_dim = object_features.shape
+ object_features = object_features.view(bs * n_objects, object_feature_dim, 1, 1).expand(
+ -1, -1, *self.initial_conv_size
+ )
+ object_features = self.pos_embedding(object_features, self.grid.unsqueeze(0))
+
+ # Apply deconvolution and restore object dimension.
+ output = self.decoder(object_features)
+ output = output.view(bs, n_objects, *output.shape[-3:])
+
+ # Split rgb and density channels and transform to appropriate ranges.
+ rgbs, densities = output.split([3, self.depth_positions], dim=2)
+ rgbs = torch.sigmoid(rgbs) # B x S x 3 x H x W
+ densities = F.softplus(densities + self.initial_density_offset) # B x S x Z x H x W
+
+ if self.normalize_densities_along_slots:
+ densities_depthwise_sum = torch.einsum("bszhw -> bzhw", densities).unsqueeze(1)
+ densities_weighted = densities * F.softmax(densities, dim=1)
+ densities_weighted_sum = torch.einsum("bszhw -> bzhw", densities_weighted).unsqueeze(1)
+ densities = densities_weighted * densities_depthwise_sum / densities_weighted_sum
+
+ # Combine densities from different slots by summing over slot dimension
+ density = torch.einsum("bszhw -> bzhw", densities).unsqueeze(2)
+ # Combine colors from different slots by density-weighted mean
+ rgb = torch.einsum("bszhw, bschw -> bzchw", densities, rgbs) / density
+
+ if self.white_background:
+ background = torch.full_like(rgb[:, 0], 1.0) # White color, i.e. 0xFFFFFF
+ else:
+ background = None
+
+ reconstruction, _, _, p_ray_hits_point = volume_rendering(
+ density, rgb, background=background
+ )
+
+ if self.training:
+ # Get object masks by taking the max density over all depth positions
+ masks = 1 - torch.exp(-densities.detach().max(dim=2).values)
+ object_reconstructions = rgbs.detach() * masks.unsqueeze(2)
+
+ if background is not None:
+ masks = torch.cat((masks, p_ray_hits_point[:, -1:, 0]), dim=1)
+ object_reconstructions = torch.cat(
+ (object_reconstructions, background[:, None]), dim=1
+ )
+
+ return ReconstructionOutput(
+ reconstruction=reconstruction,
+ object_reconstructions=object_reconstructions,
+ masks=masks,
+ )
+ else:
+ object_reconstructions, object_masks, object_depth_map = self._render_objectwise(
+ densities, rgbs
+ )
+
+ # Joint depth map results from taking minimum depth over objects per pixel, whereas
+ # joint mask results from the index of the object with minimum depth
+ depth_map, mask_dense = object_depth_map.min(1)
+
+ if background is not None:
+ object_reconstructions = torch.cat(
+ (object_reconstructions, background[:, None]), dim=1
+ )
+ # Assign designated background class wherever the depth map indicates background
+ mask_dense[depth_map == self.depth_positions] = n_objects
+ n_classes = n_objects + 1
+ else:
+ n_classes = n_objects
+
+ masks = F.one_hot(mask_dense, num_classes=n_classes)
+ masks = masks.squeeze(1).permute(0, 3, 1, 2).contiguous() # B x C x H x W
+
+ return DepthReconstructionOutput(
+ reconstruction=reconstruction,
+ object_reconstructions=object_reconstructions,
+ masks=masks,
+ masks_amodal=object_masks,
+ depth_map=depth_map,
+ object_depth_map=object_depth_map,
+ densities=densities,
+ colors=rgbs.unsqueeze(2).expand(-1, -1, self.depth_positions, -1, -1, -1),
+ )
+
+
+def volume_rendering(
+ densities: torch.Tensor,
+ colors: torch.Tensor,
+ distances: Optional[Union[float, torch.Tensor]] = None,
+ background: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Volume render along camera rays (also known as alpha compositing).
+
+ For each ray, assumes input of Z density and C color channels, corresponding to Z points along
+ the ray from front to back of the scene.
+
+ Args:
+ densities: Tensor of shape (B, Z, 1, ...). Non-negative, real valued density values along
+ the ray.
+ colors: Tensor of shape (B, Z, C, ...). Color values along the ray.
+ distances: Tensor of shape (B, Z, 1, ...). Optional distances between this ray point and
+ the next. Can also be a single float value. If not given, distances between all points
+ are assumed to be one. The last value corresponds to the distance between the last point
+ and the background.
+ background: Tensor of shape (B, C, ...). An optional background image that the rendering can
+ be put on.
+
+ Returns:
+ Tuple of tensors of shape (B, C, ...), (B, Z, C, ...), (B, Z, 1, ...), (B, Z, 1, ...).
+ First tensor is the rendered image, second tensor are the rendered images along different
+ points of the ray, third tensor the alpha masks for each point of the ray, fourth tensor the
+ probabilities of reaching each point of the ray (the transmittances). If background is not
+ None, the background is included as the last ray point.
+ """
+ if distances is None:
+ transmittances = torch.exp(-torch.cumsum(densities, dim=1))
+ p_ray_reflects = 1.0 - torch.exp(-densities)
+ else:
+ densities_distance_weighted = densities * distances
+ transmittances = torch.exp(-torch.cumsum(densities_distance_weighted, dim=1))
+ p_ray_reflects = 1.0 - torch.exp(-densities_distance_weighted)
+
+ # The first point along the ray is reached with probability 1 as it cannot be occluded by earlier points
+ p_ray_hits_point = torch.cat((torch.ones_like(densities[:, :1]), transmittances), dim=1)
+
+ if background is not None:
+ background = background.unsqueeze(1)
+
+ # All rays reaching the background reflect
+ p_ray_reflects = torch.cat((p_ray_reflects, torch.ones_like(p_ray_reflects[:, :1])), dim=1)
+ colors = torch.cat((colors, background), dim=1)
+ else:
+ p_ray_hits_point = p_ray_hits_point[:, :-1]
+
+ z_images = p_ray_reflects * colors
+ image = (p_ray_hits_point * z_images).sum(dim=1)
+
+ return image, z_images, p_ray_reflects, p_ray_hits_point
+
+
+class DVAEDecoder(nn.Module, RoutableMixin):
+ """VQ Decoder used in the original SLATE paper."""
+
+ def __init__(
+ self,
+ decoder: nn.Module,
+ patch_size: int = 4,
+ features_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"features": features_path})
+ self.initial_conv_size = (patch_size, patch_size)
+ self.decoder = decoder
+
+ @RoutableMixin.route
+ def forward(self, features: torch.Tensor):
+ rgb = self.decoder(features)
+ return SimpleReconstructionOutput(reconstruction=rgb)
+
+
+class BBoxAndClsDecoder(nn.Module, RoutableMixin):
+ """Decoder used for multiple object tracking."""
+
+ def __init__(
+ self,
+ object_dim: int,
+ hidden_dim: int,
+ num_layers: int,
+ num_classes: int,
+ img_w: int,
+ img_h: int,
+ score_thresh: float = 0.5,
+ filter_score_thresh: float = 0.3,
+ miss_tolerance: float = 5,
+ object_features_path: Optional[str] = OBJECTS,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"object_features": object_features_path})
+ self.num_layers = num_layers
+ self.num_classes = num_classes
+ self.img_w = img_w
+ self.img_h = img_h
+ self.score_thresh = score_thresh
+ self.filter_score_thresh = filter_score_thresh
+ self.miss_tolerance = miss_tolerance
+ h = [hidden_dim] * (num_layers - 1)
+ self.bbox_layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([object_dim] + h, h + [4]))
+
+ self.cls_layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([object_dim] + h, h + [num_classes])
+ )
+
+ def update(self, track_instances: Instances):
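+ # Assign a new object id to confident, previously unmatched tracks; tracks whose score drops
+ # below the filter threshold accumulate disappearance time.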
+ track_instances.disappear_time[track_instances.scores >= self.score_thresh] = 0
+ for i in range(len(track_instances)):
+ if track_instances.obj_idxes[i] == -1 and track_instances.scores[i] >= self.score_thresh:
+ track_instances.obj_idxes[i] = self.max_obj_id
+ self.max_obj_id += 1
+ elif (
+ track_instances.obj_idxes[i] >= 0
+ and track_instances.scores[i] < self.filter_score_thresh
+ ):
+ track_instances.disappear_time[i] += 1
+ # When using SAVi structure where no new det query,
+ # we do not kill tracks.
+ """
+ if track_instances.disappear_time[i] >= self.miss_tolerance:
+ # Set the obj_id to -1.
+ # Then this track will be removed by TrackEmbeddingLayer.
+ track_instances.obj_idxes[i] = -1
+ """
+
+ # Note: this logic is duplicated in the loss.
+ def _generate_empty_tracks(self, num_queries, device):
+ track_instances = Instances((1, 1))
+
+ # At init, the number of track_instances is the same as slot number
+ track_instances.obj_idxes = torch.full((num_queries,), -1, dtype=torch.long, device=device)
+ track_instances.matched_gt_idxes = torch.full(
+ (num_queries,), -1, dtype=torch.long, device=device
+ )
+ track_instances.disappear_time = torch.zeros((num_queries,), dtype=torch.long, device=device)
+ track_instances.iou = torch.zeros((num_queries,), dtype=torch.float, device=device)
+ track_instances.scores = torch.zeros((num_queries,), dtype=torch.float, device=device)
+ track_instances.track_scores = torch.zeros((num_queries,), dtype=torch.float, device=device)
+ track_instances.pred_boxes = torch.zeros((num_queries, 4), dtype=torch.float, device=device)
+ track_instances.pred_logits = torch.zeros(
+ (num_queries, self.num_classes), dtype=torch.float, device=device
+ )
+
+ return track_instances.to(device)
+
+ @RoutableMixin.route
+ def forward(self, object_features: torch.Tensor):
+ assert object_features.dim() >= 3 # Image or video data.
+ bbox_x = object_features
+ cls_x = object_features
+
+ for i, layer in enumerate(self.bbox_layers):
+ bbox_x = F.relu(layer(bbox_x)) if i < self.num_layers - 1 else layer(bbox_x)
+ for i, layer in enumerate(self.cls_layers):
+ cls_x = F.relu(layer(cls_x)) if i < self.num_layers - 1 else layer(cls_x)
+ # Apply a sigmoid to obtain per-class confidence scores.
+ cls_x = torch.sigmoid(cls_x)
+ # Make sure bbox predictions lie in [0, 1], since boxes are normalized to that range.
+ bbox_x = torch.sigmoid(bbox_x)
+
+ # Postprocessing
+ # cls_x to class idx
+ scores, labels = cls_x.max(-1)
+ # converting bbox_x to ori image space
+ ori_res_bbox_x = box_cxcywh_to_xyxy(bbox_x)
+ ori_res_bbox_x[:, :, :, 0::2] = ori_res_bbox_x[:, :, :, 0::2] * self.img_w
+ ori_res_bbox_x[:, :, :, 1::2] = ori_res_bbox_x[:, :, :, 1::2] * self.img_h
+
+ batch_size, num_frames, num_queries, _ = object_features.shape
+ device = bbox_x.device
+
+ inference_obj_idxes = torch.full(
+ (
+ batch_size,
+ num_frames,
+ num_queries,
+ ),
+ -1,
+ dtype=torch.long,
+ device=device,
+ )
+
+ for cidx in range(batch_size):
+ # Init empty prediction tracks
+ track_instances = self._generate_empty_tracks(num_queries, device)
+ self.max_obj_id = 0
+ for fidx in range(num_frames):
+ track_instances.scores = scores[cidx, fidx]
+ track_instances.pred_boxes = ori_res_bbox_x[cidx, fidx]
+ self.update(track_instances)
+ inference_obj_idxes[cidx, fidx] = track_instances.obj_idxes
+
+ return BBoxOutput(
+ bboxes=bbox_x,
+ classes=cls_x,
+ ori_res_bboxes=ori_res_bbox_x,
+ inference_obj_idxes=inference_obj_idxes,
+ )
diff --git a/ocl/distillation.py b/ocl/distillation.py
new file mode 100644
index 0000000..2e1b610
--- /dev/null
+++ b/ocl/distillation.py
@@ -0,0 +1,95 @@
+import copy
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from torch import nn
+
+from ocl import scheduling
+from ocl.utils.routing import Combined
+from ocl.utils.trees import get_tree_element
+
+
+class EMASelfDistillation(nn.Module):
+ def __init__(
+ self,
+ student: Union[nn.Module, Dict[str, nn.Module]],
+ schedule: scheduling.HPScheduler,
+ student_remapping: Optional[Dict[str, str]] = None,
+ teacher_remapping: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__()
+ # Do this for convenience to reduce crazy amount of nesting.
+ if isinstance(student, dict):
+ student = Combined(student)
+ if student_remapping is None:
+ student_remapping = {}
+ if teacher_remapping is None:
+ teacher_remapping = {}
+
+ self.student = student
+ self.teacher = copy.deepcopy(student)
+ self.schedule = schedule
+ self.student_remapping = {key: value.split(".") for key, value in student_remapping.items()}
+ self.teacher_remapping = {key: value.split(".") for key, value in teacher_remapping.items()}
+
+ def build_input_dict(self, inputs, remapping):
+ if not remapping:
+ return inputs
+ # This allows us to bring the initial input and previous_output into a similar format.
+ output_dict = {}
+ for output_path, input_path in remapping.items():
+ source = get_tree_element(inputs, input_path)
+
+ output_path = output_path.split(".")
+ cur_search = output_dict
+ for path_part in output_path[:-1]:
+ # Iterate along path and create nodes that do not exist yet.
+ try:
+ # Get element prior to last.
+ cur_search = get_tree_element(cur_search, [path_part])
+ except ValueError:
+ # Element does not yet exist.
+ cur_search[path_part] = {}
+ cur_search = cur_search[path_part]
+
+ cur_search[output_path[-1]] = source
+ return output_dict
+
+ def forward(self, inputs: Dict[str, Any]):
+ if self.training:
+ with torch.no_grad():
+ m = self.schedule(inputs["global_step"]) # momentum parameter
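+ # EMA update of the teacher parameters: param_k <- m * param_k + (1 - m) * param_q.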
+ for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
+
+ # prefix variable similar to combined module.
+ prefix: List[str]
+ if "prefix" in inputs.keys():
+ prefix = inputs["prefix"]
+ else:
+ prefix = []
+ inputs["prefix"] = prefix
+
+ outputs = get_tree_element(inputs, prefix)
+
+ # Forward pass student.
+ prefix.append("student")
+ outputs["student"] = {}
+ student_inputs = self.build_input_dict(inputs, self.student_remapping)
+ outputs["student"] = self.student(inputs={**inputs, **student_inputs})
+ # Teacher and student share the same code, thus paths also need to be the same. To ensure
+ # that, we save the student outputs and run the teacher as if it were the student.
+ student_output = outputs["student"]
+
+ # Forward pass teacher, but pretending to be student.
+ outputs["student"] = {}
+ teacher_inputs = self.build_input_dict(inputs, self.teacher_remapping)
+
+ with torch.no_grad():
+ outputs["teacher"] = self.teacher(inputs={**inputs, **teacher_inputs})
+ prefix.pop()
+
+ # Set correct outputs again.
+ outputs["student"] = student_output
+
+ return outputs
diff --git a/ocl/feature_extractors.py b/ocl/feature_extractors.py
new file mode 100644
index 0000000..68827c4
--- /dev/null
+++ b/ocl/feature_extractors.py
@@ -0,0 +1,1006 @@
+"""Implementation of feature extractors."""
+import enum
+import itertools
+import math
+from functools import partial
+from typing import Callable, List, Optional, Union
+
+import torch
+from torch import nn
+
+from ocl import base, path_defaults
+from ocl.utils.routing import RoutableMixin
+
+
+def cnn_compute_positions_and_flatten(features: torch.Tensor):
+ """Flatten output image CNN output and return it with positions of the features."""
+ # todo(hornmax): see how this works with vision transformer based architectures.
+ spatial_dims = features.shape[2:]
+ positions = torch.cartesian_prod(
+ *[torch.linspace(0.0, 1.0, steps=dim, device=features.device) for dim in spatial_dims]
+ )
+ # reorder into format (batch_size, flattened_spatial_dims, feature_dim).
+ flattened = torch.permute(features.view(features.shape[:2] + (-1,)), (0, 2, 1)).contiguous()
+ return positions, flattened
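+
+# Shape reference for `cnn_compute_positions_and_flatten`: for `features` of shape
+# (B, C, H, W) it returns `positions` of shape (H * W, 2) with coordinates in [0, 1]
+# and `flattened` of shape (B, H * W, C).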
+
+
+def transformer_compute_positions(features: torch.Tensor):
+ """Compute positions for Transformer features."""
+ n_tokens = features.shape[1]
+ image_size = math.sqrt(n_tokens)
+ image_size_int = int(image_size)
+ assert (
+ image_size_int == image_size
+ ), "Position computation for Transformers requires square image"
+
+ spatial_dims = (image_size_int, image_size_int)
+ positions = torch.cartesian_prod(
+ *[torch.linspace(0.0, 1.0, steps=dim, device=features.device) for dim in spatial_dims]
+ )
+
+ return positions
+
+
+class ImageFeatureExtractor(base.FeatureExtractor, RoutableMixin):
+ """Feature extractor which operates on images.
+
+ For these we reshape the frame dimension into the batch dimension and process the frames as
+ individual images.
+ """
+
+ def __init__(self, video_path: Optional[str] = path_defaults.VIDEO):
+ base.FeatureExtractor.__init__(self)
+ RoutableMixin.__init__(self, {"video": video_path, "global_step": path_defaults.GLOBAL_STEP})
+
+ def forward_images(self, images: torch.Tensor):
+ pass
+
+ @RoutableMixin.route
+ def forward(self, video: torch.Tensor) -> base.FeatureExtractorOutput:
+ # print("video.shape")
+ # print(video.shape)
+
+ ndim = video.dim()
+ assert ndim == 4 or ndim == 5
+
+ if ndim == 5:
+ # Handling video data.
+ bs, frames, channels, height, width = video.shape
+ images = video.view(bs * frames, channels, height, width).contiguous()
+ else:
+ images = video
+
+ result = self.forward_images(images)
+
+ if len(result) == 2:
+ positions, features = result
+ aux_features = None
+ elif len(result) == 3:
+ positions, features, aux_features = result
+
+ if ndim == 5:
+ features = features.unflatten(0, (bs, frames))
+ if aux_features is not None:
+ aux_features = {k: f.unflatten(0, (bs, frames)) for k, f in aux_features.items()}
+
+ return base.FeatureExtractorOutput(features, positions, aux_features)
+
+
+class ClipImageModel(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ model_type: str,
+ image_path: Optional[str] = path_defaults.VIDEO,
+ freeze_model: bool = False,
+ reset_weights: bool = False,
+ remove_pooling: bool = False,
+ ):
+ try:
+ import clip
+ except ImportError:
+ raise Exception("Using clip models requires installation with extra `clip`.")
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"image": image_path})
+ self.freeze_model = freeze_model
+
+ self.clip_vision_model = clip.load(
+ model_type,
+ # Initially force cpu to ensure tensors are float32 (load routine automatically converts
+ # to half precision if GPUs are detected). We can still do half-precision training via
+ # pytorch lightning if we want to.
+ device="cpu",
+ )[0].visual
+ if self.freeze_model:
+ for parameter in self.clip_vision_model.parameters():
+ parameter.requires_grad_(False)
+
+ if reset_weights:
+
+ def weight_reset(module):
+ if hasattr(module, "reset_parameters"):
+ module.reset_parameters()
+
+ self.clip_vision_model.apply(weight_reset)
+ self.clip_vision_model.initialize_parameters()
+
+ if remove_pooling:
+ if isinstance(self.clip_vision_model, clip.model.VisionTransformer):
+ self.get_output = self._get_features_from_vision_transformer
+ else:
+ self.get_output = self._get_features_from_resnet
+ else:
+ self.get_output = self.clip_vision_model
+
+ def _get_features_from_vision_transformer(self, x):
+ # Commands from:
+ # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L223
+ model = self.clip_vision_model
+
+ x = model.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat(
+ [
+ model.class_embedding
+ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+ x,
+ ],
+ dim=1,
+ ) # shape = [*, grid ** 2 + 1, width]
+ x = x + model.positional_embedding
+ x = model.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = model.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+        x = model.ln_post(x)
+ return x
+
+ def _get_features_from_resnet(self, x):
+ # Commands from:
+ # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L138
+
+ model = self.clip_vision_model
+ # Apply "stem".
+ x = model.relu1(model.bn1(model.conv1(x)))
+ x = model.relu2(model.bn2(model.conv2(x)))
+ x = model.relu3(model.bn3(model.conv3(x)))
+ x = model.avgpool(x)
+
+ x = model.layer1(x)
+ x = model.layer2(x)
+ x = model.layer3(x)
+ x = model.layer4(x)
+ return x
+
+ @RoutableMixin.route
+ def forward(self, image: torch.Tensor):
+ if self.freeze_model:
+ with torch.no_grad():
+ return self.get_output(image)
+ else:
+ return self.get_output(image)
+
+
+class ClipTextModel(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ model_type: str,
+ text_path: Optional[str] = path_defaults.TEXT,
+ freeze_model: bool = False,
+ reset_weights: bool = False,
+ remove_pooling: bool = False,
+ remove_eot: bool = False,
+ ):
+ try:
+ import clip
+ except ImportError:
+ raise Exception("Using clip models requires installation with extra `clip`.")
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"text": text_path})
+ self.freeze_model = freeze_model
+ self.remove_pooling = remove_pooling
+
+ clip_model = clip.load(
+ model_type,
+ # Initially force cpu to ensure tensors are float32 (load routine automatically converts
+ # to half precision if GPUs are detected). We can still do half-precision training via
+ # pytorch lightning if we want to.
+ device="cpu",
+ )[0]
+ if reset_weights:
+
+ def weight_reset(module):
+ if hasattr(module, "reset_parameters"):
+ module.reset_parameters()
+
+ clip_model.apply(weight_reset)
+ clip_model.initialize_parameters()
+
+ self.token_embedding = clip_model.token_embedding
+ self.positional_embedding = clip_model.positional_embedding
+ self.transformer = clip_model.transformer
+ self.ln_final = clip_model.ln_final
+ self.text_projection = clip_model.text_projection
+
+ if self.freeze_model:
+ for parameter in self.parameters():
+ parameter.requires_grad_(False)
+
+ self.remove_pooling = remove_pooling
+ self.remove_eot = remove_eot
+
+ def get_output(self, text):
+ # Based on:
+ # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ if self.remove_pooling:
+ # Mask out tokens which are part of the padding.
+ # Get position of eot token, it has the highest value of all tokens.
+ lengths = text.argmax(dim=-1)
+ if self.remove_eot:
+ # Also mask out the eot token.
+ lengths = lengths - 1
+ indices = torch.arange(x.shape[1], device=text.device)
+ mask = indices.unsqueeze(0) >= lengths
+ x.masked_fill_(mask, 0.0)
+
+ x = x @ self.text_projection
+ else:
+ # Do what is done in the standard clip text encoder.
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ @RoutableMixin.route
+ def forward(self, text: torch.Tensor):
+ if self.freeze_model:
+ with torch.no_grad():
+ return self.get_output(text)
+ else:
+ return self.get_output(text)
+
+
+class ClipFeatureExtractor(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ model_type: str,
+ image_path: Optional[str] = path_defaults.VIDEO,
+ text_path: Optional[str] = path_defaults.TEXT,
+ keep_image_model: bool = True,
+ keep_text_model: bool = True,
+ ):
+ try:
+ import clip
+ except ImportError:
+ raise Exception("Using clip models requires installation with extra `clip`.")
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"image": image_path, "text": text_path})
+ # Load returns model and preprocessing function, we only want the former.
+ clip_model = clip.load(
+ model_type,
+ # Initially force cpu to ensure tensors are float32 (load routine automatically converts
+ # to half precision if GPUs are detected). We can still do half-precision training via
+ # pytorch lightning if we want to.
+ device="cpu",
+ )[0]
+        # Keep a reference to the loaded CLIP model; `forward` uses its encode_image
+        # and encode_text methods. NOTE: `keep_image_model` and `keep_text_model` are
+        # currently not used to drop the respective submodules.
+        self.clip_model = clip_model
+
+ @RoutableMixin.route
+ def forward(self, image: torch.Tensor, text: torch.Tensor):
+ return {
+ "image": self.clip_model.encode_image(image),
+ "text": self.clip_model.encode_text(text),
+ }
+
+
+class VitFeatureType(enum.Enum):
+ BLOCK = 1
+ KEY = 2
+ VALUE = 3
+ QUERY = 4
+ CLS = 5
+
+
+class VitFeatureHook:
+ """Auxilliary class used to extract features from timm ViT models.
+
+ Args:
+        feature_type: Type of feature to extract.
+        block: Index of the block to extract features from. Note that blocks are one-indexed.
+        drop_cls_token: Whether to drop the CLS token from the extracted features.
+ """
+
+ def __init__(self, feature_type: VitFeatureType, block: int, drop_cls_token: bool = True):
+ assert isinstance(feature_type, VitFeatureType)
+ self.feature_type = feature_type
+ self.block = block
+ self.drop_cls_token = drop_cls_token
+ self.name = f"{feature_type.name.lower()}{block}"
+ self.remove_handle = None # Can be used to remove this hook from the model again
+
+ self._features = None
+
+ @staticmethod
+ def create_hook_from_feature_level(feature_level: Union[int, str]):
+ feature_level = str(feature_level)
+ prefixes = ("key", "query", "value", "block", "cls")
+ for prefix in prefixes:
+ if feature_level.startswith(prefix):
+ _, _, block = feature_level.partition(prefix)
+ feature_type = VitFeatureType[prefix.upper()]
+ block = int(block)
+ break
+ else:
+ feature_type = VitFeatureType.BLOCK
+ try:
+ block = int(feature_level)
+ except ValueError:
+ raise ValueError(f"Can not interpret feature_level '{feature_level}'.")
+
+ return VitFeatureHook(feature_type, block)
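+
+    # Parsing examples for `create_hook_from_feature_level` (illustrative inputs):
+    #   "block1" -> (VitFeatureType.BLOCK, block 1)
+    #   "key12"  -> (VitFeatureType.KEY, block 12)
+    #   "5"      -> (VitFeatureType.BLOCK, block 5)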
+
+ def register_with(self, model):
+ import timm
+
+ supported_models = (
+ timm.models.vision_transformer.VisionTransformer,
+ timm.models.beit.Beit,
+ )
+ if not isinstance(model, supported_models):
+ raise ValueError(
+ f"This hook only supports classes {', '.join(str(cl) for cl in supported_models)}."
+ )
+
+ if self.block > len(model.blocks):
+ raise ValueError(
+ f"Trying to extract features of block {self.block}, but model only has "
+ f"{len(model.blocks)} blocks"
+ )
+
+ block = model.blocks[self.block - 1]
+ if self.feature_type == VitFeatureType.BLOCK:
+ self.remove_handle = block.register_forward_hook(self)
+ else:
+ if isinstance(block, timm.models.vision_transformer.ParallelBlock):
+ raise ValueError(
+ f"ViT with `ParallelBlock` not supported for {self.feature_type} extraction."
+ )
+ elif isinstance(model, timm.models.beit.Beit):
+ raise ValueError(f"BEIT not supported for {self.feature_type} extraction.")
+ self.remove_handle = block.attn.qkv.register_forward_hook(self)
+
+ return self
+
+ def pop(self) -> torch.Tensor:
+ """Remove and return extracted feature from this hook.
+
+ We only allow access to the features this way to not have any lingering references to them.
+ """
+ assert self._features is not None, "Feature extractor was not called yet!"
+ features = self._features
+ self._features = None
+ return features
+
+ def __call__(self, module, inp, outp):
+ if self.feature_type == VitFeatureType.BLOCK:
+ features = outp
+ if self.drop_cls_token:
+ # First token is CLS token.
+ features = features[:, 1:]
+ elif self.feature_type in {VitFeatureType.KEY, VitFeatureType.QUERY, VitFeatureType.VALUE}:
+ # This part is adapted from the timm implementation. Unfortunately, there is no more
+ # elegant way to access keys, values, or queries.
+ B, N, C = inp[0].shape
+ qkv = outp.reshape(B, N, 3, C) # outp has shape B, N, 3 * H * (C // H)
+ q, k, v = qkv.unbind(2)
+
+ if self.feature_type == VitFeatureType.QUERY:
+ features = q
+ elif self.feature_type == VitFeatureType.KEY:
+ features = k
+ else:
+ features = v
+ if self.drop_cls_token:
+ # First token is CLS token.
+ features = features[:, 1:]
+ elif self.feature_type == VitFeatureType.CLS:
+ # We ignore self.drop_cls_token in this case as it doesn't make any sense.
+ features = outp[:, 0] # Only get class token.
+ else:
+ raise ValueError("Invalid VitFeatureType provided.")
+
+ self._features = features
+
+
+class TimmFeatureExtractor(ImageFeatureExtractor):
+ """Feature extractor implementation for timm models.
+
+ Args:
+ model_name: Name of model. See `timm.list_models("*")` for available options.
+ feature_level: Level of features to return. For CNN-based models, a single integer. For ViT
+ models, either a single or a list of feature descriptors. If a list is passed, multiple
+ levels of features are extracted and concatenated. A ViT feature descriptor consists of
+ the type of feature to extract, followed by an integer indicating the ViT block whose
+ features to use. The type of features can be one of "block", "key", "query", "value",
+ specifying that the block's output, attention keys, query or value should be used. If
+ omitted, assumes "block" as the type. Example: "block1" or ["block1", "value2"].
+        aux_features: Features to store as auxiliary features. The format is the same as in the
+ `feature_level` argument. Features are stored as a dictionary, using their string
+ representation (e.g. "block1") as the key. Only valid for ViT models.
+ pretrained: Whether to load pretrained weights.
+ freeze: Whether the weights of the feature extractor should be trainable.
+ n_blocks_to_unfreeze: Number of blocks that should be trainable, beginning from the last
+ block.
+ unfreeze_attention: Whether weights of ViT attention layers should be trainable (only valid
+ for ViT models). According to http://arxiv.org/abs/2203.09795, finetuning attention
+ layers only can yield better results in some cases, while being slightly cheaper in terms
+ of computation and memory.
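+
+    Example:
+        A minimal, illustrative instantiation (any timm ViT name and any valid feature
+        descriptor can be substituted):
+
+            TimmFeatureExtractor("vit_small_patch16_224", feature_level="block12", freeze=True)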
+ """
+
+ def __init__(
+ self,
+ model_name: str,
+ feature_level: Optional[Union[int, str, List[Union[int, str]]]] = None,
+ aux_features: Optional[Union[int, str, List[Union[int, str]]]] = None,
+ pretrained: bool = False,
+ freeze: bool = False,
+ n_blocks_to_unfreeze: int = 0,
+ unfreeze_attention: bool = False,
+ video_path: Optional[str] = path_defaults.VIDEO,
+ ):
+ super().__init__(video_path)
+ try:
+ import timm
+ except ImportError:
+ raise Exception("Using timm models requires installation with extra `timm`.")
+
+ register_custom_timm_models()
+
+ self.is_vit = model_name.startswith("vit") or model_name.startswith("beit")
+
+ def feature_level_to_list(feature_level):
+ if feature_level is None:
+ return []
+ elif isinstance(feature_level, (int, str)):
+ return [feature_level]
+ else:
+ return list(feature_level)
+
+ self.feature_levels = feature_level_to_list(feature_level)
+ self.aux_features = feature_level_to_list(aux_features)
+
+ if self.is_vit:
+ model = timm.create_model(model_name, pretrained=pretrained)
+ # Delete unused parameters from classification head
+ if hasattr(model, "head"):
+ del model.head
+ if hasattr(model, "fc_norm"):
+ del model.fc_norm
+
+ if len(self.feature_levels) > 0 or len(self.aux_features) > 0:
+ self._feature_hooks = [
+ VitFeatureHook.create_hook_from_feature_level(level).register_with(model)
+ for level in itertools.chain(self.feature_levels, self.aux_features)
+ ]
+ if len(self.feature_levels) > 0:
+ feature_dim = model.num_features * len(self.feature_levels)
+
+ # Remove modules not needed in computation of features
+ max_block = max(hook.block for hook in self._feature_hooks)
+ new_blocks = model.blocks[:max_block] # Creates a copy
+ del model.blocks
+ model.blocks = new_blocks
+ model.norm = nn.Identity()
+ else:
+ feature_dim = model.num_features
+ else:
+ self._feature_hooks = None
+ feature_dim = model.num_features
+ else:
+ if len(self.feature_levels) == 0:
+ raise ValueError(
+ f"Feature extractor {model_name} requires specifying `feature_level`"
+ )
+ elif len(self.feature_levels) != 1:
+ raise ValueError(
+ f"Feature extractor {model_name} only supports a single `feature_level`"
+ )
+ elif not isinstance(self.feature_levels[0], int):
+ raise ValueError("`feature_level` needs to be an integer")
+
+ if len(self.aux_features) > 0:
+ raise ValueError("`aux_features` not supported by feature extractor {model_name}")
+
+ model = timm.create_model(
+ model_name,
+ pretrained=pretrained,
+ features_only=True,
+ out_indices=self.feature_levels,
+ )
+ feature_dim = model.feature_info.channels()[0]
+
+ self.model = model
+ self.freeze = freeze
+ self.n_blocks_to_unfreeze = n_blocks_to_unfreeze
+ self._feature_dim = feature_dim
+
+ if freeze:
+ self.model.requires_grad_(False)
+ # BatchNorm layers update their statistics in train mode. This is probably not desired
+ # when the model is supposed to be frozen.
+ contains_bn = any(
+ isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
+ for m in self.model.modules()
+ )
+ self.run_in_eval_mode = contains_bn
+ else:
+ self.run_in_eval_mode = False
+
+ if self.n_blocks_to_unfreeze > 0:
+ if not self.is_vit:
+ raise NotImplementedError(
+ "`unfreeze_n_blocks` option only implemented for ViT models"
+ )
+ self.model.blocks[-self.n_blocks_to_unfreeze :].requires_grad_(True)
+ if self.model.norm is not None:
+ self.model.norm.requires_grad_(True)
+
+ if unfreeze_attention:
+ if not self.is_vit:
+ raise ValueError("`unfreeze_attention` option only works with ViT models")
+ for module in self.model.modules():
+ if isinstance(module, timm.models.vision_transformer.Attention):
+ module.requires_grad_(True)
+
+ @property
+ def feature_dim(self):
+ return self._feature_dim
+
+ def forward_images(self, images: torch.Tensor):
+ if self.run_in_eval_mode and self.training:
+ self.eval()
+
+ if self.is_vit:
+ if self.freeze and self.n_blocks_to_unfreeze == 0:
+ # Speed things up a bit by not requiring grad computation.
+ with torch.no_grad():
+ features = self.model.forward_features(images)
+ else:
+ features = self.model.forward_features(images)
+
+ if self._feature_hooks is not None:
+ hook_features = [hook.pop() for hook in self._feature_hooks]
+
+ if len(self.feature_levels) == 0:
+ # Remove class token when not using hooks.
+ features = features[:, 1:]
+ positions = transformer_compute_positions(features)
+ else:
+ features = hook_features[: len(self.feature_levels)]
+ positions = transformer_compute_positions(features[0])
+ features = torch.cat(features, dim=-1)
+
+ if len(self.aux_features) > 0:
+ aux_hooks = self._feature_hooks[len(self.feature_levels) :]
+ aux_features = hook_features[len(self.feature_levels) :]
+ aux_features = {hook.name: feat for hook, feat in zip(aux_hooks, aux_features)}
+ else:
+ aux_features = None
+ else:
+ features = self.model(images)[0]
+ positions, features = cnn_compute_positions_and_flatten(features)
+ aux_features = None
+
+ return positions, features, aux_features
+
+
+class SlotAttentionFeatureExtractor(ImageFeatureExtractor):
+ """Feature extractor as used in slot attention paper."""
+
+ def __init__(self, video_path: Optional[str] = path_defaults.VIDEO):
+ super().__init__(video_path)
+ self.layers = nn.Sequential(
+ nn.Conv2d(3, out_channels=64, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64, out_channels=64, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64, out_channels=64, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64, out_channels=64, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ )
+
+ @property
+ def feature_dim(self):
+ return 64
+
+ def forward_images(self, images: torch.Tensor):
+ features = self.layers(images)
+ positions, flattened = cnn_compute_positions_and_flatten(features)
+ return positions, flattened
+
+
+class SAViFeatureExtractor(ImageFeatureExtractor):
+ """Feature extractor as used in the slot attention for video paper."""
+
+ def __init__(self, larger_input_arch=False, video_path: Optional[str] = path_defaults.VIDEO):
+ """Feature extractor as used in the slot attention for video paper.
+
+ Args:
+ larger_input_arch: Use the architecture for larger image datasets such as MOVi++, which
+                uses a strided convolution in the first layer and a higher number of feature
+                channels in the CNN backbone.
+            video_path: Path to the input video (or image).
+ """
+ super().__init__(video_path=video_path)
+ self.larger_input_arch = larger_input_arch
+ if larger_input_arch:
+ self.layers = nn.Sequential(
+ # Pytorch does not support stride>1 with padding=same.
+ # Implement tensorflow behaviour manually.
+ # See: https://discuss.pytorch.org/t/same-padding-equivalent-in-pytorch/85121/4
+ nn.ZeroPad2d((1, 2, 1, 2)),
+ nn.Conv2d(3, out_channels=64, kernel_size=5, stride=2, padding="valid"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64, out_channels=64, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64, out_channels=64, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64, out_channels=64, kernel_size=5, padding="same"),
+ )
+ else:
+ self.layers = nn.Sequential(
+ nn.Conv2d(3, out_channels=32, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(32, out_channels=32, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(32, out_channels=32, kernel_size=5, padding="same"),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(32, out_channels=32, kernel_size=5, padding="same"),
+ )
+
+ @property
+ def feature_dim(self):
+ return 64 if self.larger_input_arch else 32
+
+ def forward_images(self, images: torch.Tensor):
+ features = self.layers(images)
+ positions, flattened = cnn_compute_positions_and_flatten(features)
+ return positions, flattened
+
+
+def register_custom_timm_models():
+ import timm
+ from timm.models import layers, resnet, vision_transformer
+ from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg
+
+ @timm.models.registry.register_model
+ def resnet34_savi(pretrained=False, **kwargs):
+ """ResNet34 as used in SAVi and SAVi++.
+
+ As of now, no official code including the ResNet was released, so we can only guess which of
+ the numerous ResNet variants was used. This modifies the basic timm ResNet34 to have 1x1
+        strides in the stem, and replaces batch norm with group norm. Gives feature maps at 1/8 of
+        the input resolution, i.e. 16x16 feature maps for an input size of 128x128.
+
+ From SAVi:
+ > For the modified SAVi (ResNet) model on MOVi++, we replace the convolutional backbone [...]
+ > with a ResNet-34 backbone. We use a modified ResNet root block without strides
+ > (i.e. 1Ă—1 stride), resulting in 16Ă—16 feature maps after the backbone [w. 128x128 images].
+ > We further use group normalization throughout the ResNet backbone.
+
+ From SAVi++:
+ > We used a ResNet-34 backbone with modified root convolutional layer that has 1Ă—1 stride.
+ > For all layers, we replaced the batch normalization operation by group normalization.
+ """
+ if pretrained:
+ raise ValueError("No pretrained weights available for `savi_resnet34`.")
+
+ model_args = dict(
+ block=resnet.BasicBlock, layers=[3, 4, 6, 3], norm_layer=layers.GroupNorm, **kwargs
+ )
+ model = resnet._create_resnet("resnet34", pretrained=pretrained, **model_args)
+ model.conv1.stride = (1, 1)
+ model.maxpool.stride = (1, 1)
+ return model
+
+ @timm.models.registry.register_model
+ def resnet50_dino(pretrained=False, **kwargs):
+ kwargs["pretrained_cfg"] = resnet._cfg(
+ url=(
+ "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/"
+ "dino_resnet50_pretrain.pth"
+ )
+ )
+ model_args = dict(block=resnet.Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return build_model_with_cfg(resnet.ResNet, "resnet50_dino", pretrained, **model_args)
+
+ def add_moco_positional_embedding(model, temperature=10000.0):
+ """Moco ViT uses 2d sincos embedding."""
+ h, w = model.patch_embed.grid_size
+ grid_w = torch.arange(w, dtype=torch.float32)
+ grid_h = torch.arange(h, dtype=torch.float32)
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+ assert (
+ model.embed_dim % 4 == 0
+ ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
+ pos_dim = model.embed_dim // 4
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+ omega = 1.0 / (temperature**omega)
+ out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
+ out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
+ pos_emb = torch.cat(
+ [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1
+ )[None, :, :]
+ if hasattr(model, "num_tokens"): # Old timm versions
+ assert model.num_tokens == 1, "Assuming one and only one token, [cls]"
+ else:
+ assert model.num_prefix_tokens == 1, "Assuming one and only one token, [cls]"
+ pe_token = torch.zeros([1, 1, model.embed_dim], dtype=torch.float32)
+ model.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
+ model.pos_embed.requires_grad = False
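+
+    # For reference, the 2D sin-cos scheme built above: with frequencies
+    # omega_d = 1 / temperature ** (d / pos_dim), each grid coordinate g contributes
+    # sin(g * omega_d) and cos(g * omega_d) along the w- and h-axes, and a zero vector
+    # is prepended for the [cls] token.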
+
+ def moco_checkpoint_filter_fn(state_dict, model, linear_name):
+ state_dict = state_dict["state_dict"]
+
+ for k in list(state_dict.keys()):
+ # retain only base_encoder up to before the embedding layer
+ if k.startswith("module.base_encoder") and not k.startswith(
+ f"module.base_encoder.{linear_name}"
+ ):
+ # remove prefix
+ state_dict[k[len("module.base_encoder.") :]] = state_dict[k]
+ # delete renamed or unused k
+ del state_dict[k]
+
+ return state_dict
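+
+    # Key remapping example (illustrative): a checkpoint entry such as
+    # "module.base_encoder.blocks.0.attn.qkv.weight" is re-stored as
+    # "blocks.0.attn.qkv.weight", while the original key and any key belonging to the
+    # projection head (`linear_name`) are removed.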
+
+ def create_moco_vit(variant, pretrained=False, **kwargs):
+ if kwargs.get("features_only", None):
+ raise RuntimeError("features_only not implemented for Vision Transformer models.")
+
+ pretrained_cfg = resolve_pretrained_cfg(
+ variant, pretrained_cfg=kwargs.pop("pretrained_cfg", None)
+ )
+ model = build_model_with_cfg(
+ vision_transformer.VisionTransformer,
+ variant,
+ pretrained,
+ pretrained_cfg=pretrained_cfg,
+ pretrained_filter_fn=partial(moco_checkpoint_filter_fn, linear_name="head"),
+ pretrained_custom_load=False,
+ **kwargs,
+ )
+ add_moco_positional_embedding(model)
+ return model
+
+ @timm.models.registry.register_model
+ def vit_small_patch16_224_mocov3(pretrained=False, **kwargs):
+ kwargs["pretrained_cfg"] = vision_transformer._cfg(
+ url="https://dl.fbaipublicfiles.com/moco-v3/vit-s-300ep/vit-s-300ep.pth.tar"
+ )
+ model_kwargs = dict(
+ patch_size=16,
+ embed_dim=384,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ num_classes=0,
+ **kwargs,
+ )
+ model = create_moco_vit(
+ "vit_small_patch16_224_mocov3", pretrained=pretrained, **model_kwargs
+ )
+ return model
+
+ @timm.models.registry.register_model
+ def vit_base_patch16_224_mocov3(pretrained=False, **kwargs):
+ kwargs["pretrained_cfg"] = vision_transformer._cfg(
+ url="https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar"
+ )
+ model_kwargs = dict(
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ num_classes=0,
+ **kwargs,
+ )
+ model = create_moco_vit("vit_base_patch16_224_mocov3", pretrained=pretrained, **model_kwargs)
+ return model
+
+ @timm.models.registry.register_model
+ def resnet50_mocov3(pretrained=False, **kwargs):
+ kwargs["pretrained_cfg"] = resnet._cfg(
+ url="https://dl.fbaipublicfiles.com/moco-v3/r-50-1000ep/r-50-1000ep.pth.tar"
+ )
+ model_args = dict(block=resnet.Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return build_model_with_cfg(
+ resnet.ResNet,
+ "resnet50_mocov3",
+ pretrained,
+ pretrained_filter_fn=partial(moco_checkpoint_filter_fn, linear_name="fc"),
+ **model_args,
+ )
+
+ def msn_vit_checkpoint_filter_fn(state_dict, model):
+ state_dict = state_dict["target_encoder"]
+
+ for k in list(state_dict.keys()):
+ if not k.startswith("module.fc."):
+ # remove prefix
+ state_dict[k[len("module.") :]] = state_dict[k]
+ # delete renamed or unused k
+ del state_dict[k]
+
+ return state_dict
+
+ def create_msn_vit(variant, pretrained=False, **kwargs):
+ if kwargs.get("features_only", None):
+ raise RuntimeError("features_only not implemented for Vision Transformer models.")
+
+ pretrained_cfg = resolve_pretrained_cfg(
+ variant, pretrained_cfg=kwargs.pop("pretrained_cfg", None)
+ )
+ model = build_model_with_cfg(
+ vision_transformer.VisionTransformer,
+ variant,
+ pretrained,
+ pretrained_cfg=pretrained_cfg,
+ pretrained_filter_fn=msn_vit_checkpoint_filter_fn,
+ pretrained_custom_load=False,
+ **kwargs,
+ )
+ return model
+
+ @timm.models.registry.register_model
+ def vit_small_patch16_224_msn(pretrained=False, **kwargs):
+ kwargs["pretrained_cfg"] = vision_transformer._cfg(
+ url="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar"
+ )
+ model_kwargs = dict(
+ patch_size=16,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ num_classes=0,
+ **kwargs,
+ )
+ model = create_msn_vit("vit_small_patch16_224_msn", pretrained=pretrained, **model_kwargs)
+ return model
+
+ @timm.models.registry.register_model
+ def vit_base_patch16_224_msn(pretrained=False, **kwargs):
+ kwargs["pretrained_cfg"] = vision_transformer._cfg(
+ url="https://dl.fbaipublicfiles.com/msn/vitb16_600ep.pth.tar"
+ )
+ model_kwargs = dict(
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ num_classes=0,
+ **kwargs,
+ )
+ model = create_msn_vit("vit_base_patch16_224_msn", pretrained=pretrained, **model_kwargs)
+ return model
+
+ @timm.models.registry.register_model
+ def vit_base_patch16_224_mae(pretrained=False, **kwargs):
+ from timm.models.vision_transformer import _create_vision_transformer
+
+ kwargs["pretrained_cfg"] = vision_transformer._cfg(
+ url="https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
+ )
+ model_kwargs = dict(
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ num_classes=0,
+ **kwargs,
+ )
+ model = _create_vision_transformer(
+ "vit_base_patch16_224_mae", pretrained=pretrained, **model_kwargs
+ )
+ return model
+
+
+class DVAEFeatureExtractor(ImageFeatureExtractor):
+ """DVAE VQ Encoder in SLATE."""
+
+ def __init__(
+ self,
+ encoder: nn.Module,
+ positional_encoder: nn.Module,
+ dictionary: nn.Module,
+ tau: Callable,
+ hard: bool = False,
+ video_path: Optional[str] = path_defaults.VIDEO,
+ ):
+ """Feature extractor as used in the SLATE paper.
+
+ Args:
+            encoder: torch Module that transforms an image into patch representations.
+            positional_encoder: torch Module that adds positional encoding.
+            dictionary: map from one-hot vectors to embeddings.
+            tau: temperature for gumbel_softmax.
+ hard: hard gumbel_softmax if True.
+ video_path: path to original inputs.
+ """
+ super().__init__(video_path)
+ self.global_step = None
+ self.tau = tau
+ self.hard = hard
+ self.dictionary = dictionary
+ self.positional_encoder = positional_encoder
+ self.encoder = encoder
+
+ @property
+ def feature_dim(self):
+ return 64
+
+ def forward_images(self, images: torch.Tensor):
+ z_logits = nn.functional.log_softmax(self.encoder(images), dim=1)
+ _, _, H_enc, W_enc = z_logits.size()
+ z = nn.functional.gumbel_softmax(z_logits, self.tau(self.global_step), self.hard, dim=1)
+ z_hard = nn.functional.gumbel_softmax(
+ z_logits, self.tau(self.global_step), True, dim=1
+ ).detach()
+
+ # add beginning of sequence (BOS) token
+ # [1, 0, 0, 0, ...] is encoding for BOS token
+ # and each sequence starts from such token
+ z_hard = z_hard.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2)
+ # add first zeros column to the z_hard matrix
+ z_transformer_input = torch.cat([torch.zeros_like(z_hard[..., :1]), z_hard], dim=-1)
+ # add first zeros row to the z_hard matrix
+ z_transformer_input = torch.cat(
+ [torch.zeros_like(z_transformer_input[..., :1, :]), z_transformer_input], dim=-2
+ )
+ # fill new row and column with one,
+ # so that we added [1, 0, 0, 0, ...] token
+ z_transformer_input[:, 0, 0] = 1.0
+
+ # tokens to embeddings
+ features = self.dictionary(z_transformer_input)
+ features = self.positional_encoder(features)
+
+ slot_attention_features = features[:, 1:]
+
+ transformer_input = features[:, :-1]
+ aux_features = {
+ "z": z,
+ "targets": transformer_input,
+ "z_hard": z_hard,
+ }
+ return None, slot_attention_features, aux_features
+
+ @RoutableMixin.route
+ def forward(self, video: torch.Tensor, global_step: int) -> base.FeatureExtractorOutput:
+ self.global_step = global_step
+ return super().forward(video=video)
diff --git a/ocl/hooks.py b/ocl/hooks.py
new file mode 100644
index 0000000..b1f115a
--- /dev/null
+++ b/ocl/hooks.py
@@ -0,0 +1,76 @@
+from typing import Any, Callable, Dict, Tuple
+
+import webdataset
+from pluggy import HookimplMarker, HookspecMarker
+
+from ocl.combined_model import CombinedModel
+
+hook_specification = HookspecMarker("ocl")
+hook_implementation = HookimplMarker("ocl")
+
+
+class FakeHooks:
+ """Class that mimics the behavior of the plugin manager hooks property."""
+
+ def __getattr__(self, attribute):
+ """Return a fake hook handler for any attribute query."""
+
+ def fake_hook_handler(*args, **kwargs):
+ return tuple()
+
+ return fake_hook_handler
+
+
+# @transform_hooks
+# def input_dependencies() -> Tuple[str, ...]:
+# """Provide list of variables that are required for the plugin to function."""
+#
+#
+# @transform_hooks
+# def provided_inputs() -> Tuple[str, ...]:
+# """Provide list of variables that are provided by the plugin."""
+
+
+@hook_specification
+def training_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
+ """Provide a transformation which processes a component of a webdataset pipeline."""
+
+
+@hook_specification
+def training_batch_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
+ """Provide a transformation which processes a batched component of a webdataset pipeline."""
+
+
+@hook_specification
+def training_fields() -> Tuple[str]:
+ """Provide list of fields that are required to be decoded during training."""
+
+
+@hook_specification
+def evaluation_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
+ """Provide a transformation which processes a component of a webdataset pipeline."""
+
+
+@hook_specification
+def evaluation_batch_transform() -> Callable[[webdataset.Processor], webdataset.Processor]:
+ """Provide a transformation which processes a batched component of a webdataset pipeline."""
+
+
+@hook_specification
+def evaluation_fields() -> Tuple[str]:
+ """Provide list of fields that are required to be decoded during evaluation."""
+
+
+@hook_specification
+def configure_optimizers(model: CombinedModel) -> Dict[str, Any]:
+ """Return optimizers in the format of pytorch lightning."""
+
+
+@hook_specification
+def on_train_start(model: CombinedModel) -> None:
+ """Hook called when starting training."""
+
+
+@hook_specification
+def on_train_epoch_start(model: CombinedModel) -> None:
+ """Hook called when starting training epoch."""
diff --git a/ocl/losses.py b/ocl/losses.py
new file mode 100644
index 0000000..573bd33
--- /dev/null
+++ b/ocl/losses.py
@@ -0,0 +1,1010 @@
+from functools import partial
+from math import log
+from typing import Callable, List, Optional, Union
+
+import pytorch_lightning as pl
+import torch
+from einops import rearrange, repeat
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from torchvision.ops import generalized_box_iou
+
+from ocl import base, consistency, path_defaults, scheduling
+from ocl.base import Instances
+from ocl.matching import CPUHungarianMatcher
+from ocl.utils.bboxes import box_cxcywh_to_xyxy
+from ocl.utils.routing import RoutableMixin
+
+
+def _constant_weight(weight: float, global_step: int):
+ return weight
+
+
+class ReconstructionLoss(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ loss_type: str,
+ weight: Union[Callable, float] = 1.0,
+ normalize_target: bool = False,
+ input_path: Optional[str] = None,
+ target_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {"input": input_path, "target": target_path, "global_step": path_defaults.GLOBAL_STEP},
+ )
+ if loss_type == "mse":
+ self.loss_fn = nn.functional.mse_loss
+ elif loss_type == "mse_sum":
+ # Used for slot_attention and video slot attention.
+ self.loss_fn = (
+ lambda x1, x2: nn.functional.mse_loss(x1, x2, reduction="sum") / x1.shape[0]
+ )
+ elif loss_type == "l1":
+ self.loss_name = "l1_loss"
+ self.loss_fn = nn.functional.l1_loss
+ elif loss_type == "cosine":
+ self.loss_name = "cosine_loss"
+ self.loss_fn = lambda x1, x2: -nn.functional.cosine_similarity(x1, x2, dim=-1).mean()
+ elif loss_type == "cross_entropy_sum":
+ # Used for SLATE, average is over the first (batch) dim only.
+ self.loss_name = "cross_entropy_sum_loss"
+ self.loss_fn = (
+ lambda x1, x2: nn.functional.cross_entropy(
+ x1.reshape(-1, x1.shape[-1]), x2.reshape(-1, x2.shape[-1]), reduction="sum"
+ )
+ / x1.shape[0]
+ )
+ else:
+ raise ValueError(
+ f"Unknown loss {loss_type}. Valid choices are (mse, l1, cosine, cross_entropy)."
+ )
+ # If weight is callable use it to determine scheduling otherwise use constant value.
+ self.weight = weight if callable(weight) else partial(_constant_weight, weight)
+ self.normalize_target = normalize_target
+
+ @RoutableMixin.route
+ def forward(self, input: torch.Tensor, target: torch.Tensor, global_step: int):
+ target = target.detach()
+ if self.normalize_target:
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
+
+ loss = self.loss_fn(input, target)
+ weight = self.weight(global_step)
+ return weight * loss
+
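+# Scheduling note: the weight-style arguments of several losses in this module accept
+# either a plain float or any callable taking the global step, e.g. a linear warm-up
+# such as `lambda step: min(1.0, step / 10_000)` (this concrete schedule is purely
+# illustrative and not one of the shipped configurations).
+
+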
+class SparsePenalty(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ linear_weight: Union[Callable, float] = 1.0,
+ quadratic_weight: Union[Callable, float] = 0.0,
+ quadratic_bias: Union[Callable, float] = 0.5,
+ input_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {"input": input_path, "global_step": path_defaults.GLOBAL_STEP},
+ )
+
+        self.linear_weight = (
+            linear_weight if callable(linear_weight) else partial(_constant_weight, linear_weight)
+        )
+        self.quadratic_weight = (
+            quadratic_weight
+            if callable(quadratic_weight)
+            else partial(_constant_weight, quadratic_weight)
+        )
+        self.quadratic_bias = (
+            quadratic_bias if callable(quadratic_bias) else partial(_constant_weight, quadratic_bias)
+        )
+
+ @RoutableMixin.route
+ def forward(self, input: torch.Tensor, global_step: int):
+ # print("spaese_input.shape")
+ # print(input.shape)
+ sparse_degree = torch.mean(input)
+
+ linear_weight = self.linear_weight(global_step)
+ quadratic_weight = self.quadratic_weight(global_step)
+ quadratic_bias = self.quadratic_bias(global_step)
+
+        return (
+            linear_weight * sparse_degree
+            + quadratic_weight * (sparse_degree - quadratic_bias) ** 2
+        )
+ #return linear_weight * sparse_degree + quadratic_weight * torch.mean((input - quadratic_bias) ** 2)
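+
+    # Penalty sketch: with s = mean(input) (e.g. the average slot keep-probability in
+    # AdaSlot), the value returned above is
+    #     linear_weight * s + quadratic_weight * (s - quadratic_bias) ** 2,
+    # so the quadratic term drives the expected fraction of kept slots towards
+    # `quadratic_bias`, while the linear term penalizes keeping slots at all.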
+
+
+
+class LatentDupplicateSuppressionLoss(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ weight: Union[float, scheduling.HPSchedulerT],
+ eps: float = 1e-08,
+ grouping_path: Optional[str] = "perceptual_grouping",
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self, {"grouping": grouping_path, "global_step": path_defaults.GLOBAL_STEP}
+ )
+ self.weight = weight
+ self.similarity = nn.CosineSimilarity(dim=-1, eps=eps)
+
+ @RoutableMixin.route
+ def forward(self, grouping: base.PerceptualGroupingOutput, global_step: int):
+ if grouping.objects.dim() == 4:
+ # Build large tensor of reconstructed video.
+ objects = grouping.objects
+ bs, n_frames, n_objects, n_features = objects.shape
+
+ off_diag_indices = torch.triu_indices(
+ n_objects, n_objects, offset=1, device=objects.device
+ )
+
+ sq_similarities = (
+ self.similarity(
+ objects[:, :, off_diag_indices[0], :], objects[:, :, off_diag_indices[1], :]
+ )
+ ** 2
+ )
+
+ if grouping.is_empty is not None:
+ p_not_empty = 1.0 - grouping.is_empty
+                # Assume that the probability of individual objects being present is independent,
+ # thus the probability of both being present is the product of the individual
+ # probabilities.
+ p_pair_present = (
+ p_not_empty[..., off_diag_indices[0]] * p_not_empty[..., off_diag_indices[1]]
+ )
+ # Use average expected penalty as loss for each frame.
+ losses = (sq_similarities * p_pair_present) / torch.sum(
+ p_pair_present, dim=-1, keepdim=True
+ )
+ else:
+ losses = sq_similarities.mean(dim=-1)
+
+ weight = self.weight(global_step) if callable(self.weight) else self.weight
+ return weight * losses.sum() / (bs * n_frames)
+ elif grouping.objects.dim() == 3:
+ # Build large tensor of reconstructed image.
+ objects = grouping.objects
+ bs, n_objects, n_features = objects.shape
+
+ off_diag_indices = torch.triu_indices(
+ n_objects, n_objects, offset=1, device=objects.device
+ )
+
+ sq_similarities = (
+ self.similarity(
+ objects[:, off_diag_indices[0], :], objects[:, off_diag_indices[1], :]
+ )
+ ** 2
+ )
+
+ if grouping.is_empty is not None:
+ p_not_empty = 1.0 - grouping.is_empty
+                # Assume that the probability of individual objects being present is independent,
+ # thus the probability of both being present is the product of the individual
+ # probabilities.
+ p_pair_present = (
+ p_not_empty[..., off_diag_indices[0]] * p_not_empty[..., off_diag_indices[1]]
+ )
+ # Use average expected penalty as loss for each frame.
+ losses = (sq_similarities * p_pair_present) / torch.sum(
+ p_pair_present, dim=-1, keepdim=True
+ )
+ else:
+ losses = sq_similarities.mean(dim=-1)
+
+ weight = self.weight(global_step) if callable(self.weight) else self.weight
+ return weight * losses.sum() / bs
+ else:
+ raise ValueError("Incompatible input format.")
+
+
+class ConsistencyLoss(nn.Module, RoutableMixin):
+ """Task that returns the previously extracted objects.
+
+ Intended to make the object representations accessible to downstream functions, e.g. metrics.
+ """
+
+ def __init__(
+ self,
+ matcher: consistency.HungarianMatcher,
+ loss_type: str = "CE",
+ loss_weight: float = 0.25,
+ mask_path: Optional[str] = None,
+ mask_target_path: Optional[str] = None,
+ params_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "mask": mask_path,
+ "mask_target": mask_target_path,
+ "cropping_params": params_path,
+ "global_step": path_defaults.GLOBAL_STEP,
+ },
+ )
+ self.matcher = matcher
+ if loss_type == "CE":
+ self.loss_name = "masks_consistency_CE"
+ self.weight = (
+ loss_weight if callable(loss_weight) else partial(_constant_weight, loss_weight)
+ )
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ mask: torch.Tensor,
+ mask_target: torch.Tensor,
+ cropping_params: torch.Tensor,
+ global_step: int,
+ ):
+ _, n_objects, size, _ = mask.shape
+ mask_one_hot = self._to_binary_mask(mask)
+ mask_target = self.crop_views(mask_target, cropping_params, size)
+ mask_target_one_hot = self._to_binary_mask(mask_target)
+ match = self.matcher(mask_one_hot, mask_target_one_hot)
+ matched_mask = torch.stack([mask[match[i, 1]] for i, mask in enumerate(mask)])
+ assert matched_mask.shape == mask.shape
+ assert mask_target.shape == mask.shape
+ flattened_matched_mask = matched_mask.permute(0, 2, 3, 1).reshape(-1, n_objects)
+ flattened_mask_target = mask_target.permute(0, 2, 3, 1).reshape(-1, n_objects)
+ weight = self.weight(global_step) if callable(self.weight) else self.weight
+ return weight * self.loss_fn(flattened_matched_mask, flattened_mask_target)
+
+ @staticmethod
+ def _to_binary_mask(masks: torch.Tensor):
+ _, n_objects, _, _ = masks.shape
+        m_labels = masks.argmax(dim=1)
+        mask_one_hot = torch.nn.functional.one_hot(m_labels, n_objects)
+ return mask_one_hot.permute(0, 3, 1, 2)
+
+ def crop_views(self, view: torch.Tensor, param: torch.Tensor, size: int):
+        return torch.cat([self.crop_mapping(v, p, size) for v, p in zip(view, param)])
+
+ @staticmethod
+    def crop_mapping(view: torch.Tensor, p: torch.Tensor, size: int):
+ p = tuple(p.cpu().numpy().astype(int))
+ return transforms.functional.resized_crop(view, *p, size=(size, size))[None]
+
+
+def focal_loss(
+    inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, mean_in_dim1=True
+):
+ """Loss used in RetinaNet for dense detection. # noqa: D411.
+
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+        num_boxes: Normalization constant that the summed loss is divided by.
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+            positive vs negative examples. Default = 0.25; a negative value
+            disables the weighting.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+            balance easy vs hard examples.
+        mean_in_dim1: Whether to average over dim 1 before summing and dividing by num_boxes.
+ Returns:
+ Loss tensor
+ """
+ prob = inputs
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+ if mean_in_dim1:
+ return loss.mean(1).sum() / num_boxes
+ else:
+ return loss.sum() / num_boxes
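+
+# Per-element focal term computed above, written out for reference:
+#     loss = alpha_t * (1 - p_t) ** gamma * BCE_with_logits(inputs, targets)
+# with p_t = p * y + (1 - p) * (1 - y) and alpha_t = alpha * y + (1 - alpha) * (1 - y),
+# where p denotes `inputs` and y denotes `targets`.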
+
+
+def CompDETRCostMatrix(
+ outputs,
+ targets,
+ use_focal=True,
+ class_weight: float = 1,
+ bbox_weight: float = 1,
+ giou_weight: float = 1,
+):
+ """Compute cost matrix between outputs instances and target instances.
+
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes]
+ with the classification logits
+ "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the
+ predicted box coordinates
+
+        targets: a list of targets (len(targets) = batch_size), where each target is a dict with:
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
+ ground-truth objects in the target) containing the class labels
+ "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+
+ Returns:
+        costMatrix: An iterable of tensors of shape [batch_size, num_queries, num_targets_i],
+            one entry per target set.
+ """
+ with torch.no_grad():
+ bs, num_queries = outputs["pred_logits"].shape[:2]
+
+ # We flatten to compute the cost matrices in a batch
+ if use_focal:
+ out_prob = outputs["pred_logits"].flatten(0, 1)
+ else:
+ AssertionError("only support focal for now.")
+ out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
+
+ # Also concat the target labels and boxes
+ tgt_ids = torch.cat([v["labels"] for v in targets])
+ tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+ # Compute the classification cost.
+ if use_focal:
+ alpha = 0.25
+ gamma = 2.0
+ neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+ cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+ else:
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+            # The 1 is a constant that doesn't change the matching, it can be omitted.
+ cost_class = -out_prob[:, tgt_ids]
+
+ # Compute the L1 cost between boxes
+ cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+        # Compute the giou cost between boxes
+ cost_giou = -generalized_box_iou( # noqa: F821
+ box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) # noqa: F821
+ )
+
+ # Final cost matrix
+ C = bbox_weight * cost_bbox + class_weight * cost_class + giou_weight * cost_giou
+ C = C.view(bs, num_queries, -1).cpu()
+
+ sizes = [len(v["boxes"]) for v in targets]
+
+ return C.split(sizes, -1)
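+
+# Output shape example (illustrative numbers): with batch_size=2, num_queries=7 and two
+# targets holding 3 and 5 boxes respectively, the returned split consists of two tensors
+# of shape (2, 7, 3) and (2, 7, 5).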
+
+
+class MOTRLoss(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ num_classes: int = 1,
+ loss_weight: float = 1.0,
+ input_bbox_path: Optional[str] = None,
+ target_bbox_path: Optional[str] = None,
+ input_cls_path: Optional[str] = None,
+ target_cls_path: Optional[str] = None,
+ target_id_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ self.num_classes = num_classes
+ self.loss_weight = loss_weight
+ RoutableMixin.__init__(
+ self,
+ {
+ "input_bbox": input_bbox_path,
+ "target_bbox": target_bbox_path,
+ "input_cls": input_cls_path,
+ "target_cls": target_cls_path,
+ "target_id": target_id_path,
+ "global_step": path_defaults.GLOBAL_STEP,
+ },
+ )
+ self.matcher = CPUHungarianMatcher()
+
+ def bbox_loss(self, input_bbox, target_bbox):
+ return F.l1_loss(input_bbox, target_bbox, reduction="mean")
+
+ def objectness_loss(self, input_objectness, target_objectness):
+ batch_size, num_objects = target_objectness.shape
+ return F.cross_entropy(
+ input_objectness.view(batch_size * num_objects, -1),
+ target_objectness.view(batch_size * num_objects),
+ )
+
+ def clip_target_to_instances(self, clip_target_cls, clip_target_bbox, clip_target_id):
+ """Converting the target in one clip into instances.
+
+ Args:
+            clip_target_cls (torch.Tensor): object classes per frame; -1 marks background.
+            clip_target_bbox (torch.Tensor): object bounding boxes per frame.
+            clip_target_id (torch.Tensor): object track ids per frame.
+ """
+ num_frames = clip_target_bbox.shape[0]
+ clip_gt_instances = []
+ for fidx in range(num_frames):
+ frame_gt_instances = base.Instances((1, 1))
+ frame_gt_cls = clip_target_cls[fidx]
+ non_empty_mask_idx = frame_gt_cls > -1
+ if non_empty_mask_idx.sum() > 0:
+ frame_gt_instances.boxes = clip_target_bbox[fidx][non_empty_mask_idx]
+ frame_gt_instances.labels = frame_gt_cls[non_empty_mask_idx]
+ frame_gt_instances.obj_ids = clip_target_id[fidx][non_empty_mask_idx]
+ clip_gt_instances.append(frame_gt_instances)
+ return clip_gt_instances
+
+ def _generate_empty_tracks(self, num_queries, device):
+ track_instances = Instances((1, 1))
+
+ # At init, the number of track_instances is the same as slot number
+ track_instances.obj_idxes = torch.full((num_queries,), -1, dtype=torch.long, device=device)
+ track_instances.matched_gt_idxes = torch.full(
+ (num_queries,), -1, dtype=torch.long, device=device
+ )
+ track_instances.disappear_time = torch.zeros((num_queries,), dtype=torch.long, device=device)
+ track_instances.iou = torch.zeros((num_queries,), dtype=torch.float, device=device)
+ track_instances.scores = torch.zeros((num_queries,), dtype=torch.float, device=device)
+ track_instances.track_scores = torch.zeros((num_queries,), dtype=torch.float, device=device)
+ track_instances.pred_boxes = torch.zeros((num_queries, 4), dtype=torch.float, device=device)
+ track_instances.pred_logits = torch.zeros(
+ (num_queries, self.num_classes), dtype=torch.float, device=device
+ )
+
+ return track_instances.to(device)
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def loss_boxes(self, outputs, gt_instances: List[Instances], indices: List[tuple], num_boxes):
+ """Compute the losses related to the bounding boxes.
+
+        targets dicts must contain the key "boxes" with a tensor of dim [nb_target_boxes, 4].
+ The target boxes are expected in format (center_x, center_y, h, w),
+ normalized by the image size.
+ """
+ # We ignore the regression loss of the track-disappear slots.
+ # TODO: Make this filter process more elegant.
+ filtered_idx = []
+ for src_per_img, tgt_per_img in indices:
+ keep = tgt_per_img != -1
+ filtered_idx.append((src_per_img[keep], tgt_per_img[keep]))
+ indices = filtered_idx
+ idx = self._get_src_permutation_idx(indices)
+ src_boxes = outputs["pred_boxes"][idx]
+ target_boxes = torch.cat(
+ [gt_per_img.boxes[i] for gt_per_img, (_, i) in zip(gt_instances, indices)], dim=0
+ )
+
+ # for pad target, don't calculate regression loss, judged by whether obj_id=-1
+ target_obj_ids = torch.cat(
+ [gt_per_img.obj_ids[i] for gt_per_img, (_, i) in zip(gt_instances, indices)], dim=0
+ ) # size(16)
+ mask = target_obj_ids != -1
+
+ # only use l1 loss for now, will consider giou loss later
+ loss_bbox = F.l1_loss(src_boxes[mask], target_boxes[mask], reduction="none")
+ loss_giou = 1 - torch.diag(
+ generalized_box_iou(
+ box_cxcywh_to_xyxy(src_boxes[mask]), box_cxcywh_to_xyxy(target_boxes[mask])
+ )
+ )
+
+ # Note, we will not normalize by the num_boxes here. Will handle later.
+ return loss_bbox.sum() + loss_giou.sum()
+
+ def loss_labels(self, outputs, gt_instances: List[Instances], indices, num_boxes, log=False):
+ """Classification loss (NLL).
+
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ src_logits = outputs["pred_logits"]
+ idx = self._get_src_permutation_idx(indices)
+ target_classes = torch.full(
+ src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+ )
+        # The matched gt for a disappeared track query is set to -1.
+ labels = []
+ for gt_per_img, (_, J) in zip(gt_instances, indices):
+ labels_per_img = torch.ones_like(J)
+ # set labels of track-appear slots to 0.
+ if len(gt_per_img) > 0:
+ labels_per_img[J != -1] = gt_per_img.labels[J[J != -1]]
+ labels.append(labels_per_img)
+ target_classes_o = torch.cat(labels)
+ target_classes[idx] = target_classes_o
+ # we use focal loss for each class
+ gt_labels_target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[
+ :, :, :-1
+ ] # no loss for the last (background) class
+ gt_labels_target = gt_labels_target.to(src_logits)
+ loss_ce = focal_loss(
+ src_logits.flatten(1),
+ gt_labels_target.flatten(1),
+ alpha=0.25,
+ gamma=2,
+ num_boxes=num_boxes,
+ mean_in_dim1=False,
+ )
+ loss_ce = loss_ce.sum()
+ return loss_ce
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ input_bbox: torch.Tensor,
+ target_bbox: torch.Tensor,
+ input_cls: torch.Tensor,
+ target_cls: torch.Tensor,
+ target_id: torch.Tensor,
+ global_step: int,
+ ):
+ target_bbox = target_bbox.detach()
+ target_cls = target_cls.detach()
+ target_id = target_id.detach()
+
+ batch_size, num_frames, num_queries, _ = input_bbox.shape
+ device = input_bbox.device
+
+ total_loss = 0
+ num_samples = 0
+
+        # Iterate through each clip. Consider whether this loop could be parallelized.
+ for cidx in range(batch_size):
+ clip_target_instances = self.clip_target_to_instances(
+ target_cls[cidx], target_bbox[cidx], target_id[cidx]
+ )
+ # Init empty prediction tracks
+ track_instances = self._generate_empty_tracks(num_queries, device)
+ for fidx in range(num_frames):
+ gt_instances_i = clip_target_instances[fidx]
+ # put the prediction at the current frame into track instances.
+ track_scores = input_cls[cidx, fidx].max(dim=-1).values
+ track_instances.scores = track_scores
+ pred_logits_i = input_cls[cidx, fidx]
+ pred_boxes_i = input_bbox[cidx, fidx]
+ outputs_i = {
+ "pred_logits": pred_logits_i.unsqueeze(0),
+ "pred_boxes": pred_boxes_i.unsqueeze(0),
+ }
+ track_instances.pred_logits = pred_logits_i
+ track_instances.pred_boxes = pred_boxes_i
+
+ # step 0: collect existing matched pairs
+ obj_idxes = gt_instances_i.obj_ids
+ obj_idxes_list = obj_idxes.detach().cpu().numpy().tolist()
+ obj_idx_to_gt_idx = {
+ obj_idx: gt_idx for gt_idx, obj_idx in enumerate(obj_idxes_list)
+ }
+
+ # step1. inherit and update the previous tracks.
+ num_disappear_track = 0
+ for j in range(len(track_instances)):
+ obj_id = track_instances.obj_idxes[j].item()
+ # set new target idx.
+ if obj_id >= 0:
+ if obj_id in obj_idx_to_gt_idx:
+ track_instances.matched_gt_idxes[j] = obj_idx_to_gt_idx[obj_id]
+ else:
+ num_disappear_track += 1
+ track_instances.matched_gt_idxes[j] = -1 # track-disappear case.
+ else:
+ track_instances.matched_gt_idxes[j] = -1
+
+ full_track_idxes = torch.arange(len(track_instances), dtype=torch.long).to(
+ input_bbox.device
+ )
+                # Occupied slots, i.e. slots already matched to an object in a previous frame.
+                matched_track_idxes = track_instances.obj_idxes >= 0
+ prev_matched_indices = torch.stack(
+ [
+ full_track_idxes[matched_track_idxes],
+ track_instances.matched_gt_idxes[matched_track_idxes],
+ ],
+ dim=1,
+ ).to(input_bbox.device)
+
+ # step2. select the unmatched slots.
+ # note that the FP tracks whose obj_idxes are -2 will not be selected here.
+ unmatched_track_idxes = full_track_idxes[track_instances.obj_idxes == -1]
+
+ # step3. select the untracked gt instances (new tracks).
+ tgt_indexes = track_instances.matched_gt_idxes
+ tgt_indexes = tgt_indexes[tgt_indexes != -1]
+ tgt_state = torch.zeros(len(gt_instances_i)).to(pred_logits_i.device)
+ tgt_state[tgt_indexes] = 1
+ untracked_tgt_indexes = torch.arange(len(gt_instances_i)).to(pred_logits_i.device)[
+ tgt_state == 0
+ ]
+ untracked_gt_instances = {
+ "labels": gt_instances_i[untracked_tgt_indexes].labels,
+ "boxes": gt_instances_i[untracked_tgt_indexes].boxes,
+ }
+
+ def match_for_single_decoder_layer(
+ unmatched_outputs,
+ matcher,
+ untracked_gt_instances,
+ unmatched_track_idxes,
+ untracked_tgt_indexes,
+ device,
+ ):
+                    costMatrix = CompDETRCostMatrix(
+                        unmatched_outputs, [untracked_gt_instances]
+                    )  # iterable of cost matrices, one per frame
+
+ new_track_indices = []
+ for c in costMatrix:
+ AssignmentMatrix, _ = self.matcher(c)
+ assert AssignmentMatrix.shape[0] == 1, "Only match for one frame."
+ new_track_indices.append(torch.where(AssignmentMatrix[0] > 0))
+
+ assert len(new_track_indices) == 1, "Only match for one frame."
+ src_idx = new_track_indices[0][0]
+ tgt_idx = new_track_indices[0][1]
+ # concat src and tgt.
+ res_new_matched_indices = torch.stack(
+ [unmatched_track_idxes[src_idx], untracked_tgt_indexes[tgt_idx]], dim=1
+ ).to(device)
+ return res_new_matched_indices
+
+ # step4. do matching between the unmatched slots and GTs.
+ unmatched_outputs = {
+ "pred_logits": track_instances.pred_logits[unmatched_track_idxes].unsqueeze(0),
+ "pred_boxes": track_instances.pred_boxes[unmatched_track_idxes].unsqueeze(0),
+ }
+ if unmatched_outputs["pred_logits"].shape[1] == 0:
+                    # NOTE: this is a hack used when trying random_strided_window.
+                    # TODO: figure out how it really works.
+ new_matched_indices = (
+ torch.zeros([0, 2]).long().to(track_instances.pred_logits.device)
+ )
+ else:
+ new_matched_indices = match_for_single_decoder_layer(
+ unmatched_outputs,
+ self.matcher,
+ untracked_gt_instances,
+ unmatched_track_idxes,
+ untracked_tgt_indexes,
+ pred_logits_i.device,
+ )
+
+ # step5. update obj_idxes according to the new matching result.
+ track_instances.obj_idxes[new_matched_indices[:, 0]] = gt_instances_i.obj_ids[
+ new_matched_indices[:, 1]
+ ].long()
+ track_instances.matched_gt_idxes[new_matched_indices[:, 0]] = new_matched_indices[
+ :, 1
+ ]
+
+ # step6. merge the new pairs and the matched pairs.
+ matched_indices = torch.cat([new_matched_indices, prev_matched_indices], dim=0)
+
+ # step7. calculate losses.
+ num_samples += len(gt_instances_i) + num_disappear_track
+ cls_loss = self.loss_labels(
+ outputs=outputs_i,
+ gt_instances=[gt_instances_i],
+ indices=[(matched_indices[:, 0], matched_indices[:, 1])],
+ num_boxes=1,
+ )
+ bbox_loss = self.loss_boxes(
+ outputs=outputs_i,
+ gt_instances=[gt_instances_i],
+ indices=[(matched_indices[:, 0], matched_indices[:, 1])],
+ num_boxes=1,
+ )
+ total_loss += cls_loss + bbox_loss
+        # A naive normalization; might not be the best choice.
+ total_loss /= num_samples
+ total_loss *= self.loss_weight
+ return total_loss
+
+
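+# Minimal, self-contained sketch of the track-inheritance bookkeeping performed in
+# "step 1" of the loss above (illustrative only; the toy values and the helper name
+# are hypothetical and nothing here is used by the training pipeline).
+def _track_inheritance_sketch():
+    obj_idxes = torch.tensor([3, 7, -1])  # per-track object ids (-1 = never matched)
+    obj_idx_to_gt_idx = {3: 0}  # object ids present in the current frame -> gt index
+    matched_gt_idxes = torch.full((3,), -1)
+    for j, obj_id in enumerate(obj_idxes.tolist()):
+        # A previously matched track keeps its identity if its object is still present;
+        # otherwise it is marked as disappeared (-1).
+        if obj_id >= 0 and obj_id in obj_idx_to_gt_idx:
+            matched_gt_idxes[j] = obj_idx_to_gt_idx[obj_id]
+    return matched_gt_idxes  # tensor([0, -1, -1])
+
+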
+class CLIPLoss(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ normalize_inputs: bool = True,
+ learn_scale: bool = True,
+ max_temperature: Optional[float] = None,
+ first_path: Optional[str] = None,
+ second_path: Optional[str] = None,
+ model_path: Optional[str] = path_defaults.MODEL,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self, {"first_rep": first_path, "second_rep": second_path, "model": model_path}
+ )
+ self.normalize_inputs = normalize_inputs
+ if learn_scale:
+            self.logit_scale = nn.Parameter(torch.ones([]) * log(1 / 0.07))  # Same init as CLIP.
+ else:
+ self.register_buffer("logit_scale", torch.zeros([])) # exp(0) = 1, i.e. no scaling.
+ self.max_temperature = max_temperature
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ first_rep: torch.Tensor,
+ second_rep: torch.Tensor,
+ model: Optional[pl.LightningModule] = None,
+ ):
+ # Collect all representations.
+ if self.normalize_inputs:
+ first_rep = first_rep / first_rep.norm(dim=-1, keepdim=True)
+ second_rep = second_rep / second_rep.norm(dim=-1, keepdim=True)
+
+ temperature = self.logit_scale.exp()
+ if self.max_temperature:
+ temperature = torch.clamp_max(temperature, self.max_temperature)
+
+ if model is not None and hasattr(model, "trainer") and model.trainer.world_size > 1:
+ # Running on multiple GPUs.
+ global_rank = model.global_rank
+ all_first_rep, all_second_rep = model.all_gather(
+ [first_rep, second_rep], sync_grads=True
+ )
+ world_size, batch_size = all_first_rep.shape[:2]
+ labels = (
+ torch.arange(batch_size, dtype=torch.long, device=first_rep.device)
+ + batch_size * global_rank
+ )
+ # Flatten the GPU dim into batch.
+ all_first_rep = all_first_rep.flatten(0, 1)
+ all_second_rep = all_second_rep.flatten(0, 1)
+
+ # Compute inner product for instances on the current GPU.
+ logits_per_first = temperature * first_rep @ all_second_rep.t()
+ logits_per_second = temperature * second_rep @ all_first_rep.t()
+
+ # For visualization purposes, return the cosine similarities on the local batch.
+ similarities = (
+ 1
+ / temperature
+ * logits_per_first[:, batch_size * global_rank : batch_size * (global_rank + 1)]
+ )
+ # shape = [local_batch_size, global_batch_size]
+ else:
+ batch_size = first_rep.shape[0]
+ labels = torch.arange(batch_size, dtype=torch.long, device=first_rep.device)
+ # When running with only a single GPU we can save some compute time by reusing
+ # computations.
+ logits_per_first = temperature * first_rep @ second_rep.t()
+ logits_per_second = logits_per_first.t()
+ similarities = 1 / temperature * logits_per_first
+
+ return (
+ (F.cross_entropy(logits_per_first, labels) + F.cross_entropy(logits_per_second, labels))
+ / 2,
+ {"similarities": similarities, "temperature": temperature},
+ )
+
+
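+# Minimal single-process sketch of the symmetric InfoNCE objective computed by
+# ``CLIPLoss`` above (illustrative only: random toy features, temperature fixed to 1,
+# no multi-GPU gathering; not used by the training pipeline).
+def _clip_loss_sketch():
+    first_rep = F.normalize(torch.randn(4, 16), dim=-1)  # (batch_size, feature_dim)
+    second_rep = F.normalize(torch.randn(4, 16), dim=-1)
+    logits = first_rep @ second_rep.t()  # cosine similarities (temperature = 1)
+    labels = torch.arange(4)  # matching pairs sit on the diagonal
+    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
+
+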
+def CompDETRSegCostMatrix(
+ predicts,
+ targets,
+):
+    """Compute the cost matrix between output instances and target instances.
+
+    Returns:
+        cost: A tensor of shape [num_outputs, num_targets] with per-pair BCE costs.
+    """
+ npr, h, w = predicts.shape
+ nt = targets.shape[0]
+
+ predicts = repeat(predicts, "npr h w -> (npr repeat) h w", repeat=nt)
+ targets = repeat(targets, "nt h w -> (repeat nt) h w", repeat=npr)
+
+ cost = F.binary_cross_entropy(predicts, targets.float(), reduction="none").mean(-1).mean(-1)
+ cost = rearrange(cost, "(npr nt) -> npr nt", npr=npr, nt=nt)
+ return cost
+
+
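+# Shape sketch for ``CompDETRSegCostMatrix`` above (illustrative only; random toy
+# masks, not used by the training pipeline): with 3 predicted masks and 2 targets of
+# resolution 8x8, the returned cost matrix has shape [3, 2], where entry (i, j) is the
+# mean per-pixel BCE between prediction i and target j.
+def _seg_cost_matrix_sketch():
+    predicts = torch.rand(3, 8, 8).clamp(1e-4, 1 - 1e-4)  # predicted masks in (0, 1)
+    targets = (torch.rand(2, 8, 8) > 0.5).float()  # binary target masks
+    return CompDETRSegCostMatrix(predicts, targets)  # shape [3, 2]
+
+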
+class DETRSegLoss(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ loss_weight: float = 1.0,
+ ignore_background: bool = True,
+ foreground_weight: float = 1.0,
+ foreground_matching_weight: float = 1.0,
+ global_loss: bool = True,
+ input_mask_path: Optional[str] = None,
+ target_mask_path: Optional[str] = None,
+ foreground_logits_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "input_mask": input_mask_path,
+ "target_mask": target_mask_path,
+ "foreground_logits": foreground_logits_path,
+ "model": path_defaults.MODEL,
+ },
+ )
+ self.loss_weight = loss_weight
+ self.ignore_background = ignore_background
+ self.foreground_weight = foreground_weight
+ self.foreground_matching_weight = foreground_matching_weight
+ self.global_loss = global_loss
+ self.matcher = CPUHungarianMatcher()
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ input_mask: torch.Tensor,
+ target_mask: torch.Tensor,
+ foreground_logits: Optional[torch.Tensor] = None,
+ model: Optional[pl.LightningModule] = None,
+ ):
+ target_mask = target_mask.detach() > 0
+ device = target_mask.device
+
+        # NaN masks are treated as invalid targets and skipped.
+ valid_targets = ~(target_mask.isnan().all(-1).all(-1)).any(-1)
+        # Discard the first mask channel as it corresponds to the background.
+ if self.ignore_background:
+ # Assume first class in masks is background.
+ if len(target_mask.shape) > 4: # Video data (bs, frame, classes, w, h).
+ target_mask = target_mask[:, :, 1:]
+ else: # Image data (bs, classes, w, h).
+ target_mask = target_mask[:, 1:]
+
+ targets = target_mask[valid_targets]
+ predictions = input_mask[valid_targets]
+ if foreground_logits is not None:
+ foreground_logits = foreground_logits[valid_targets]
+
+ total_loss = torch.tensor(0.0, device=device)
+ num_samples = 0
+
+        # Iterate over the batch elements. TODO: consider whether this loop can be parallelized.
+ for i, (prediction, target) in enumerate(zip(predictions, targets)):
+ # Filter empty masks.
+ target = target[target.sum(-1).sum(-1) > 0]
+
+ # Compute matching.
+ costMatrixSeg = CompDETRSegCostMatrix(
+ prediction,
+ target,
+ )
+ # We cannot rely on the matched cost for computing the loss due to
+ # normalization issues between segmentation component (normalized by
+ # number of matches) and classification component (normalized by
+ # number of predictions). Thus compute both components separately
+ # after deriving the matching matrix.
+ if foreground_logits is not None and self.foreground_matching_weight != 0.0:
+ # Positive classification component.
+ logits = foreground_logits[i]
+ costMatrixTotal = (
+ costMatrixSeg
+ + self.foreground_weight
+ * F.binary_cross_entropy_with_logits(
+ logits, torch.ones_like(logits), reduction="none"
+ ).detach()
+ )
+ else:
+ costMatrixTotal = costMatrixSeg
+
+ # Matcher takes a batch but we are doing this one by one.
+ matching_matrix = self.matcher(costMatrixTotal.unsqueeze(0))[0].squeeze(0)
+ n_matches = min(predictions.shape[0], target.shape[0])
+ if n_matches > 0:
+ instance_cost = (costMatrixSeg * matching_matrix).sum(-1).sum(-1) / n_matches
+ else:
+ instance_cost = torch.tensor(0.0, device=device)
+
+ if foreground_logits is not None:
+ ismatched = (matching_matrix > 0).any(-1)
+ logits = foreground_logits[i].squeeze(-1)
+ instance_cost += self.foreground_weight * F.binary_cross_entropy_with_logits(
+ logits, ismatched.float(), reduction="mean"
+ )
+
+ total_loss += instance_cost
+            # Count this sample; the loss is normalized by the number of samples after the loop.
+ num_samples += 1
+
+ if (
+ model is not None
+ and hasattr(model, "trainer")
+ and model.trainer.world_size > 1
+ and self.global_loss
+ ):
+            # As the data is sparsely labeled, return the average loss over all GPUs.
+            # This should make the loss a bit smoother.
+ all_losses, sample_counts = model.all_gather([total_loss, num_samples], sync_grads=True)
+ total_count = sample_counts.sum()
+ if total_count > 0:
+ total_loss = all_losses.sum() / total_count
+ else:
+ total_loss = torch.tensor(0.0, device=device)
+
+ return total_loss * self.loss_weight
+ else:
+ if num_samples == 0:
+ # Avoid division by zero if a batch does not contain any labels.
+ return torch.tensor(0.0, device=targets.device)
+
+ total_loss /= num_samples
+ total_loss *= self.loss_weight
+ return total_loss
+
+
+class EM_rec_loss(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ loss_weight: float = 20,
+ attn_path: Optional[str] = None,
+ rec_path: Optional[str] = None,
+ tgt_path: Optional[str] = None,
+ img_path: Optional[str] = None,
+ tgt_vis_path: Optional[str] = None,
+ weights_path: Optional[str] = None,
+ attn_index_path: Optional[str] = None,
+ slot_path: Optional[str] = None,
+ pred_feat_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "segmentations": attn_path,
+ "reconstructions": rec_path,
+ "masks": tgt_path,
+ "masks_vis": tgt_vis_path,
+ "rec_tgt": img_path,
+ "weights": weights_path,
+ "attn_index": attn_index_path,
+ "slots": slot_path,
+ "pred_slots": pred_feat_path,
+ },
+ )
+ self.loss_weight = loss_weight
+ self.loss_fn = lambda x1, x2: nn.functional.mse_loss(x1, x2, reduction="none")
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ segmentations: torch.Tensor, # rollout_decode.masks
+ masks: torch.Tensor, # decoder.masks
+ reconstructions: torch.Tensor,
+ rec_tgt: torch.Tensor,
+ masks_vis: torch.Tensor,
+ attn_index: torch.Tensor,
+ slots: torch.Tensor,
+ pred_slots: torch.Tensor,
+ smooth=1,
+ ):
+ b, f, c, h, w = segmentations.shape
+ _, _, n_slots, n_buffer = attn_index.shape
+
+ segmentations = (
+ segmentations.reshape(-1, n_buffer, h, w).unsqueeze(1).repeat(1, n_slots, 1, 1, 1)
+ )
+ masks = masks.reshape(-1, n_slots, h, w).unsqueeze(2).repeat(1, 1, n_buffer, 1, 1)
+ masks = masks > 0.5
+ masks_vis = (
+ masks_vis.reshape(-1, n_slots, h, w)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .repeat(1, 1, n_buffer, 3, 1, 1)
+ )
+ masks_vis = masks_vis > 0.5
+ attn_index = attn_index.reshape(-1, n_slots, n_buffer)
+ rec_tgt = (
+ rec_tgt.reshape(-1, 3, h, w)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .repeat(1, n_slots, n_buffer, 1, 1, 1)
+ )
+ reconstructions = (
+ reconstructions.reshape(-1, n_buffer, 3, h, w)
+ .unsqueeze(1)
+ .repeat(1, n_slots, 1, 1, 1, 1)
+ )
+ rec_pred = reconstructions * masks_vis
+ rec_tgt_ = rec_tgt * masks_vis
+ loss = torch.sum(
+ F.binary_cross_entropy(segmentations, masks.float(), reduction="none"), (-1, -2)
+ ) / (h * w) + 0.1 * torch.sum(self.loss_fn(rec_pred, rec_tgt_), (-3, -2, -1))
+ total_loss = torch.sum(attn_index * loss, (0, 1, 2)) / (b * f * n_slots * n_buffer)
+ return (total_loss) * self.loss_weight
diff --git a/ocl/matching.py b/ocl/matching.py
new file mode 100644
index 0000000..b654f39
--- /dev/null
+++ b/ocl/matching.py
@@ -0,0 +1,34 @@
+"""Methods for matching between sets of elements."""
+from typing import Tuple, Type
+
+import numpy as np
+import torch
+from scipy.optimize import linear_sum_assignment
+from torchtyping import TensorType
+
+# Avoid errors due to flake:
+batch_size = None
+n_elements = None
+
+CostMatrix = Type[TensorType["batch_size", "n_elements", "n_elements"]]
+AssignmentMatrix = Type[TensorType["batch_size", "n_elements", "n_elements"]]
+CostVector = Type[TensorType["batch_size"]]
+
+
+class Matcher(torch.nn.Module):
+ """Matcher base class to define consistent interface."""
+
+ def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]:
+        raise NotImplementedError
+
+
+class CPUHungarianMatcher(Matcher):
+    """Implementation of a CPU Hungarian matcher using scipy.optimize.linear_sum_assignment."""
+
+ def forward(self, C: CostMatrix) -> Tuple[AssignmentMatrix, CostVector]:
+ X = torch.zeros_like(C)
+ C_cpu: np.ndarray = C.detach().cpu().numpy()
+ for i, cost_matrix in enumerate(C_cpu):
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
+ X[i][row_ind, col_ind] = 1.0
+ return X, (C * X).sum(dim=(1, 2))
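+
+
+# Minimal usage sketch (illustrative only; random toy costs, not used elsewhere): the
+# matcher returns a binary assignment matrix X with one selected entry per matched
+# row/column of each square cost matrix, plus the total cost of that assignment.
+def _hungarian_matcher_sketch():
+    matcher = CPUHungarianMatcher()
+    C = torch.rand(2, 4, 4)  # (batch_size, n_elements, n_elements)
+    X, total_cost = matcher(C)
+    return X, total_cost  # X: (2, 4, 4) zero/one matrix, total_cost: shape (2,)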
diff --git a/ocl/memory.py b/ocl/memory.py
new file mode 100644
index 0000000..5bc5b95
--- /dev/null
+++ b/ocl/memory.py
@@ -0,0 +1,342 @@
+import dataclasses
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ocl import path_defaults
+from ocl.memory_rollout import GPT
+from ocl.mha import MultiHeadAttention, MultiHeadAttention_for_index
+from ocl.utils.routing import RoutableMixin
+
+
+@dataclasses.dataclass
+class MemoryOutput:
+ # rollout: TensorType["batch_size", "n_objects", "object_dim"] # noqa: F821
+ # idx_mask: TensorType["batch_size", "n_objects"] # noqa: F821
+ # matched_idx: dict # noqa: F821
+ # object_features: TensorType["batch_size", "n_objects", "n_spatial_features"] # noqa: F821
+ rollout: torch.Tensor # noqa: F821
+ object_features: torch.Tensor # noqa: F821
+ mem: torch.Tensor
+ eval_mem_features: torch.Tensor
+ table: torch.Tensor
+ attn_index: torch.Tensor
+
+
+class SelfSupervisedMemory(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ embed_dim: int = 128,
+ num_objects: int = 20,
+ memory_len: int = 30,
+ mlp_size: int = 512,
+ mlp_layer: int = 3,
+ stale_number: int = 5,
+ appearance_threshold: float = 0.2,
+ dropout_rate: float = 0.1,
+ object_features_path: Optional[str] = path_defaults.OBJECTS,
+ conditioning_path: Optional[str] = path_defaults.CONDITIONING,
+ attention_maps_path: Optional[str] = None,
+ frame_features_path: Optional[str] = path_defaults.FEATURES,
+ first_box_path: Optional[str] = None,
+ init_flag_path: Optional[str] = None,
+ matched_idx_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "observation": object_features_path,
+ "conditions": conditioning_path,
+ "attention_maps": attention_maps_path,
+ "frame_features": frame_features_path,
+ "first_frame_boxes": first_box_path,
+ "init_flag": init_flag_path,
+ "matched_index": matched_idx_path,
+ },
+ )
+ self.embed_dim = embed_dim
+ self.memory_len = memory_len
+ self.object_num = num_objects
+ self.stale_number = stale_number
+ self.threshold = appearance_threshold
+ self.num_heads = 4
+ self.roll_out_module = GPT(buffer_len=memory_len, n_layer=8, n_head=8, n_embd=embed_dim)
+ self.register_buffer("memory", torch.zeros(8, memory_len, num_objects, embed_dim))
+ self.register_buffer("memory_table", torch.zeros(8, num_objects))
+ self.register_buffer("stale_counter", torch.zeros(8, num_objects))
+ self.MultiHead_1 = MultiHeadAttention_for_index(
+ n_head=4, d_model=embed_dim, d_k=embed_dim, d_v=embed_dim
+ ) # n_head=1
+ self.MultiHead_2 = MultiHeadAttention(
+ n_head=4, d_model=embed_dim, d_k=embed_dim, d_v=embed_dim
+ )
+
+ def remove_duplicated_slot_id(self, slot_masks):
+ slot_masks = slot_masks > 0.5
+ n, h, w = slot_masks.shape
+ # remove empty slots
+ mask_sum = torch.sum(slot_masks.reshape(-1, h * w), dim=-1)
+        # In SAVi, the first slot is initialized as the background slot.
+        # BUG: this is not always the case.
+ bg_value = mask_sum[0]
+ bg_idx = (mask_sum == bg_value).nonzero(as_tuple=True)[0]
+ empty_idx = (mask_sum <= 10).nonzero(as_tuple=True)[0]
+
+ # remove duplicated masks
+ mask = slot_masks.unsqueeze(1).to(torch.bool).reshape(n, 1, -1)
+ mask_ = slot_masks.unsqueeze(0).to(torch.bool).reshape(1, n, -1)
+ intersection = torch.sum(mask & mask_, dim=-1).to(torch.float64)
+ union = torch.sum(mask | mask_, dim=-1).to(torch.float64)
+ pairwise_iou = intersection / union
+ pairwise_iou[union == 0] = 1.0
+ dup_idx = []
+ for i in range(n):
+ for j in range(i + 1, n):
+ if pairwise_iou[i, j] > 0.5:
+ dup_idx.append(i)
+ invalid_idx = [*set(list(bg_idx) + list(empty_idx) + list(dup_idx))]
+ valid_idx = []
+ for i in range(n):
+ if i not in invalid_idx:
+ valid_idx.append(i)
+ if len(empty_idx) == n:
+ valid_idx.append(0)
+
+ return valid_idx
+
+ def search_bg_id(self, slot_masks):
+ slot_masks = slot_masks > 0.5
+ n, h, w = slot_masks.shape
+ mask_sum = torch.sum(slot_masks.reshape(-1, h * w), dim=-1)
+ idx = torch.argmin(mask_sum)
+ # print(idx, mask_sum)
+ return idx
+
+ def initialization(self, box, conditions, cur_slots, cur_slot_masks, prev_slot_masks, frame_id):
+ if frame_id == 0:
+ # For each video, we should initialize the register buffers as zero
+ bs = conditions.shape[0]
+ memory_shape = (bs, self.memory_len, self.object_num, self.embed_dim)
+ memory_table_shape = (bs, self.object_num)
+ stale_counter_shape = (bs, self.object_num)
+ self.memory = torch.zeros(memory_shape).to(conditions.device)
+ self.memory_table = torch.zeros(memory_table_shape).to(conditions.device)
+ self.stale_counter = torch.zeros(stale_counter_shape).to(conditions.device)
+ for b in range(bs):
+ valid_idx = self.remove_duplicated_slot_id(cur_slot_masks[b])
+ num_obj = len(valid_idx)
+ # bg
+ self.memory[b, 0, 0, :] = conditions[b, 0, :]
+ # non duplicated objects
+ self.memory[b, 0, 1 : num_obj + 1, :] = conditions[b, valid_idx, :]
+ self.memory_table[b, : num_obj + 1] += 1
+ else:
+            # Use IoU with the previous frame's masks to find newly appeared objects.
+ bs, n, h, w = prev_slot_masks.shape
+ for b in range(bs):
+ # self.memory_eval[b, frame_id, -1, :] = ori_slots[b, 0, :]
+ cur_valid_idx = self.remove_duplicated_slot_id(cur_slot_masks[b])
+ pre_valid_idx = self.remove_duplicated_slot_id(prev_slot_masks[b])
+
+ cur_slot_mask = cur_slot_masks[b][cur_valid_idx] > 0.5
+ prev_slot_mask = prev_slot_masks[b][pre_valid_idx] > 0.5
+
+ # calculate pairwise iou
+ cur_mask = (
+ cur_slot_mask.unsqueeze(1).to(torch.bool).reshape(len(cur_valid_idx), 1, -1)
+ )
+ prev_mask = (
+ prev_slot_mask.unsqueeze(0).to(torch.bool).reshape(1, len(pre_valid_idx), -1)
+ )
+ intersection = torch.sum(cur_mask & prev_mask, dim=-1).to(torch.float64)
+ # union = torch.sum(cur_mask | prev_mask, dim=-1).to(torch.float64)
+ # pairwise_iou = intersection / union
+ # Remove NaN from divide-by-zero: class does not occur, and class was not predicted.
+ # pairwise_iou[union == 0] = 1.0
+ sim, _ = torch.max(intersection, dim=-1)
+                # NOTE: an absolute overlap threshold is used to decide whether an object is new.
+                # This might not be optimal; could cross-check with an IoU tracker's track-initialization logic.
+ new_obj_idx = list((sim < 10).nonzero(as_tuple=True)[0])
+
+ new_obj_idx_ori = [cur_valid_idx[id] for id in new_obj_idx]
+ num_new_obj = len(new_obj_idx_ori)
+
+ new_mem_idx = list((self.memory_table[b] == 0).nonzero(as_tuple=True)[0])
+ old_mem_idx = list((self.memory_table[b] != 0).nonzero(as_tuple=True)[0])
+ if num_new_obj > 0 and len(new_mem_idx) > 0:
+ last_pos = old_mem_idx[-1] + 1
+ if last_pos + num_new_obj - 1 in new_mem_idx:
+ self.memory[b, 0, last_pos : last_pos + num_new_obj] = cur_slots[
+ b, new_obj_idx_ori
+ ]
+ self.memory_table[b, last_pos : last_pos + num_new_obj] += 1
+
+ def soft_update(self, observations, predictions):
+ inputs = torch.cat((observations, predictions), -1)
+ alpha = self.amodal_prediction(inputs)
+ outputs = alpha * observations + (1 - alpha) * predictions
+ return F.normalize(outputs, dim=-1)
+
+ def buffer_terminate(self):
+ bs = self.stale_counter.shape[0]
+ for b in range(bs):
+            terminate_idx = (self.stale_counter[b] >= self.stale_number).nonzero(as_tuple=True)[0]
+ num_dead_buffer = len(list(terminate_idx))
+ tmp = torch.zeros((self.memory_len, num_dead_buffer, self.embed_dim)).to(
+ self.memory.device
+ )
+ self.memory[b, :, terminate_idx] = tmp
+ self.stale_counter[b, terminate_idx] = 0
+
+ def sms_attn_index_only(self, observations, predictions):
+        # Implemented with multi-head attention.
+ b, h, w = observations.shape
+
+ attn_o_to_p, attn_o_to_p_weights = self.MultiHead_1(
+ F.normalize(observations, dim=-1),
+ F.normalize(predictions, dim=-1),
+ F.normalize(predictions, dim=-1),
+ )
+ sim = attn_o_to_p_weights
+ mask = torch.zeros(sim.shape).to(sim.device)
+ b, h, w = mask.shape
+ for i in range(b):
+ for j in range(w):
+ index = torch.argmax(sim[i, :, j])
+ mask[i, index, j] = 1
+ b, h, w = mask.shape
+ mask = sim + (mask - sim).detach()
+ mask = mask.transpose(1, 2)
+ object_feature = torch.einsum("bcn,bnk->bck", [mask, observations])
+ # momentum update memory
+ alpha = 1
+ # print(alpha)
+ mem_features = predictions * alpha + object_feature * (1 - alpha)
+ mem_features = F.normalize(mem_features, dim=-1)
+
+ return object_feature, mem_features, attn_o_to_p_weights
+
+ def sms_attn(self, observations, predictions, eval_flag):
+ attn_o_to_p, attn_o_to_p_weights = self.MultiHead_1(observations, predictions, predictions)
+
+ mask = torch.zeros(attn_o_to_p_weights.shape).to(attn_o_to_p_weights.device)
+ b, w, h = mask.shape
+ for i in range(b):
+ for j in range(w):
+ index = torch.argmax(attn_o_to_p_weights[i, j, :])
+ mask[i, j, index] = 1
+ # mask = attn_o_to_p_weights + (mask - attn_o_to_p_weights).detach()
+
+ # attn_o_to_p_weights_gumbel = F.gumbel_softmax(attn_o_to_p_weights, tau=1, hard=True)
+ weights = mask.clone()
+
+ # MultiHead_2 layer
+ if not eval_flag:
+ attn_o_to_p_weights_trans = torch.transpose(attn_o_to_p_weights, 1, 2)
+ else:
+ attn_o_to_p_weights_trans = torch.transpose(weights, 1, 2)
+
+ attn_p_to_o, attn_p_to_o_weights = self.MultiHead_2(
+ predictions, observations, observations, mask=attn_o_to_p_weights_trans
+ )
+
+ # replace the attn_p_to_o with predictions if the buffer is not assigned
+ if eval_flag:
+ b, h, w = weights.shape
+ weights_new = torch.zeros((b, h + 1, w)).to(attn_o_to_p.device) # [b, n+1, m]
+ weights_new[:, 0:h, :] = weights
+ weights_new[:, h, :] = torch.sum(weights, dim=1)
+ weights_new_convert_zero = weights_new[:, h, :].clone()
+ weights_new_convert_zero[weights_new[:, h, :] == 0] = 1
+ weights_new_convert_zero[weights_new[:, h, :] > 0] = 0
+ weights_new[:, h, :] = weights_new_convert_zero
+ b_p, h_p, w_p = attn_p_to_o.shape # merged features
+ for j in range(b_p):
+ index = weights_new[j, h, :].nonzero(as_tuple=True)[0] # hard index
+ if len(index) > 0:
+                    # Zero out the embeddings of buffers that no slot was matched to.
+ attn_p_to_o[j][index] = torch.zeros((len(index), self.embed_dim)).to(
+ observations.device
+ )
+ # attn_p_to_o[j][index] = predictions[j][index].clone()
+ else:
+ b_p, h_p, w_p = attn_p_to_o.shape # merged features
+ for j in range(b_p):
+ # index = weights_new[j, h, :].nonzero(as_tuple=True)[0] # hard index
+ index = (self.memory_table[j] == 0).nonzero(as_tuple=True)[0]
+ if len(index) > 0:
+                    # Zero out the embeddings of buffers that no slot was matched to.
+ attn_p_to_o[j][index] = torch.zeros((len(index), self.embed_dim)).to(
+ observations.device
+ )
+ return attn_p_to_o, weights, attn_o_to_p_weights
+
+ def sms_attn_merge(self, observations, predictions):
+ # MultiHead_2 layer
+ attn_p_to_o, attn_p_to_o_weights = self.MultiHead_2(predictions, observations, observations)
+        # NOTE: relies on a ``merge_mlp`` module that is not defined in ``__init__``; unused by ``forward``.
+        attn_p_to_o = self.merge_mlp(attn_p_to_o)
+ # replace the attn_p_to_o with predictions if the buffer is not assigned
+ b_p, h_p, w_p = attn_p_to_o.shape
+ return attn_p_to_o, 0
+
+ def update_sms(self, object_features):
+ # implement for c2 sms memory update
+
+ object_features_ = object_features.clone().detach()
+
+ # update memory
+ for b in range(object_features_.shape[0]):
+ for i in range(object_features_.shape[1]):
+ tmp = torch.sum(object_features_[b, i, :], dim=-1)
+ if tmp != 0 and self.memory_table[b, i] != 0:
+ # if self.memory_table[b, i] != 0:
+ pos = self.memory_table[b, i].cpu().numpy().astype(int)
+ self.memory[b, pos, i] = object_features_[b, i]
+ # self.memory_eval[b, pos, i] = object_features_[b, i]
+ self.memory_table[b, i] += 1
+ else:
+ self.stale_counter[b, i] += 1
+
+ return object_features
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ box: torch.Tensor,
+ observations: torch.Tensor,
+ prev_slot_masks: torch.Tensor,
+ cur_slot_masks: torch.Tensor,
+ conditions: torch.Tensor,
+ frame_id: int,
+ ):
+        eval_flag = not self.training
+ self.initialization(box, conditions, observations, cur_slot_masks, prev_slot_masks, frame_id)
+ if frame_id == 0:
+ predictions = self.memory[:, 0].clone()
+ object_features = predictions.clone()
+ b, n_slots = observations.shape[:2]
+ n_buffer = predictions.shape[1]
+ attn_index = torch.zeros((b, n_slots, n_buffer)).to(observations.device)
+ else:
+ # memory roll out
+ predictions = self.roll_out_module(self.memory, self.memory_table)
+ object_features, weights, attn_index = self.sms_attn(
+                observations, predictions, eval_flag=eval_flag
+ )
+ # memory update
+ _ = self.update_sms(object_features)
+
+ # memory terminate
+        # NOTE: buffer termination is currently disabled.
+ # self.buffer_terminate()
+ return MemoryOutput(
+ rollout=predictions,
+ object_features=object_features,
+ mem=self.memory,
+ eval_mem_features=object_features,
+ table=self.memory_table,
+ attn_index=attn_index,
+ )
diff --git a/ocl/memory_rollout.py b/ocl/memory_rollout.py
new file mode 100644
index 0000000..4169423
--- /dev/null
+++ b/ocl/memory_rollout.py
@@ -0,0 +1,151 @@
+"""Memory roll-out module, following GPT-2 architecture.
+
+References:
+1) minGPT by Andrej Karpathy:
+https://github.com/karpathy/minGPT/tree/master/mingpt
+2) the official GPT-2 TensorFlow implementation released by OpenAI:
+https://github.com/openai/gpt-2/blob/master/src/model.py
+"""
+
+import math
+
+import torch
+from torch import nn
+
+# -----------------------------------------------------------------------------
+
+
+class GELU(nn.Module):
+ def forward(self, x):
+ return (
+ 0.5
+ * x
+ * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+ )
+
+
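+# Quick numerical sanity check for the tanh-approximated GELU above (illustrative only;
+# assumes a PyTorch version where ``torch.nn.functional.gelu`` supports ``approximate="tanh"``).
+def _gelu_check_sketch():
+    x = torch.linspace(-3, 3, 7)
+    reference = torch.nn.functional.gelu(x, approximate="tanh")
+    return torch.allclose(GELU()(x), reference, atol=1e-6)
+
+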
+class Block(nn.Module):
+ """One GPT-2 decoder block, consists of a Masked Self-Attn and a FFN."""
+
+ def __init__(self, n_embd, n_heads, dropout_rate):
+ super().__init__()
+ self.ln_1 = nn.LayerNorm(n_embd)
+ self.attn = nn.MultiheadAttention(n_embd, n_heads, batch_first=True)
+ self.ln_2 = nn.LayerNorm(n_embd)
+ self.mlp = nn.ModuleDict(
+ dict(
+ c_fc=nn.Linear(n_embd, 4 * n_embd),
+ c_proj=nn.Linear(4 * n_embd, n_embd),
+ act=GELU(),
+ dropout=nn.Dropout(dropout_rate),
+ )
+ )
+ m = self.mlp
+ self.ffn = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))
+
+ def forward(self, x, causal_mask):
+ att, att_weights = self.attn(
+ query=self.ln_1(x), key=self.ln_1(x), value=self.ln_1(x), attn_mask=causal_mask
+ )
+
+ x = x + att
+ x = x + self.ffn(self.ln_2(x))
+ return x, att_weights
+
+
+class GPT(nn.Module):
+ """Memory roll-out GPT."""
+
+ def __init__(
+ self, buffer_len, n_layer, n_head, n_embd, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0
+ ):
+ super().__init__()
+ self.buffer_len = buffer_len
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_embd = n_embd
+ self.embd_pdrop = embd_pdrop
+ self.resid_pdrop = resid_pdrop
+ self.attn_pdrop = attn_pdrop
+
+ self.transformer = nn.ModuleDict(
+ dict(
+ wte=nn.Linear(self.n_embd, self.n_embd, bias=False),
+ wpe=nn.Embedding(self.buffer_len, self.n_embd),
+ drop=nn.Dropout(self.embd_pdrop),
+ h=nn.ModuleList(
+ [Block(self.n_embd, self.n_head, self.resid_pdrop) for _ in range(self.n_layer)]
+ ),
+ ln_f=nn.LayerNorm(self.n_embd),
+ )
+ )
+ # roll out to the same dimension
+ self.roll_out_head = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+ # init all weights
+ self.apply(self._init_weights)
+ for pn, p in self.named_parameters():
+ if pn.endswith("c_proj.weight"):
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layer))
+
+        # report number of parameters (note we don't count the parameters of the roll-out head)
+ n_params = sum(p.numel() for p in self.transformer.parameters())
+ print("number of parameters: %.2fM" % (n_params / 1e6,))
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+ elif isinstance(module, nn.LayerNorm):
+ torch.nn.init.zeros_(module.bias)
+ torch.nn.init.ones_(module.weight)
+
+ def forward(self, mem, mem_table, targets=None):
+ device = mem.device
+ b, t, n, d = mem.shape
+
+        # reshape to merge the batch and num_buffer dimensions
+ mem = mem.permute(0, 2, 1, 3).reshape(b * n, t, d)
+ mem_table = mem_table.view(b * n, -1)
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
+
+ tok_emb = self.transformer.wte(mem) # token embeddings of shape (b, t, n_embd)
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+ x = self.transformer.drop(tok_emb + pos_emb)
+
+        # Create per-(batch, buffer) attention masks; entries set to True are blocked from attending.
+        # TODO: verify the correctness of this mask construction.
+ causal_masks = []
+ for idx in range(b * n):
+ occupied_len = mem_table[idx].cpu().numpy().astype(int)[0]
+ if occupied_len == 0:
+ occupied_len = 1
+ # causal_mask = torch.tril(torch.ones(self.buffer_len, self.buffer_len).to(device)).view(
+ # 1, self.buffer_len, self.buffer_len
+ # )
+ causal_mask = (
+ torch.zeros(self.buffer_len, self.buffer_len)
+ .to(device)
+ .view(1, self.buffer_len, self.buffer_len)
+ )
+ causal_mask[:, occupied_len:, occupied_len:] = 1
+ causal_mask = causal_mask > 0
+ causal_masks.append(causal_mask)
+ causal_masks = torch.stack(causal_masks)
+ causal_masks = causal_masks.repeat(1, self.n_head, 1, 1).view(-1, t, t)
+
+ for block in self.transformer.h:
+ x, attn_weights = block(x, causal_masks)
+ x = self.transformer.ln_f(x)
+ x = self.roll_out_head(x) # [b*n, t, d]
+
+ out = torch.zeros((b * n, d)).to(device)
+
+ for idx in range(b * n):
+ t_pos = mem_table[idx].cpu().numpy().astype(int)[0]
+ if t_pos > 0 and t_pos < t:
+ out[idx] = x[idx, t_pos - 1]
+ return out.view(b, n, d)
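+
+
+# Minimal usage sketch for the roll-out module (illustrative only; toy sizes and an
+# all-zero memory, not used by the training pipeline). ``mem`` holds per-buffer
+# histories of shape (batch, memory_len, n_buffers, embed_dim) and ``mem_table`` the
+# number of occupied time steps per buffer; one rolled-out embedding is predicted per
+# buffer.
+def _gpt_rollout_sketch():
+    gpt = GPT(buffer_len=6, n_layer=1, n_head=2, n_embd=16)
+    mem = torch.zeros(2, 6, 3, 16)  # (batch, memory_len, n_buffers, embed_dim)
+    mem_table = torch.ones(2, 3)  # one occupied time step per buffer
+    return gpt(mem, mem_table)  # shape (2, 3, 16)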
diff --git a/ocl/metrics.py b/ocl/metrics.py
new file mode 100644
index 0000000..d5e52ce
--- /dev/null
+++ b/ocl/metrics.py
@@ -0,0 +1,1374 @@
+import io
+from typing import Any, Dict, Optional
+
+import motmetrics as mm
+import numpy as np
+import pandas as pd
+import scipy.optimize
+import torch
+import torchmetrics
+import torchvision
+
+from ocl.utils.resizing import resize_patches_to_image
+from ocl.utils.routing import RoutableMixin
+
+mm.lap.default_solver = "lap"
+
+
+def tensor_to_one_hot(tensor: torch.Tensor, dim: int) -> torch.Tensor:
+ """Convert tensor to one-hot encoding by using maximum across dimension as one-hot element."""
+ assert 0 <= dim
+ max_idxs = torch.argmax(tensor, dim=dim, keepdim=True)
+ shape = [1] * dim + [-1] + [1] * (tensor.ndim - dim - 1)
+ one_hot = max_idxs == torch.arange(tensor.shape[dim], device=tensor.device).view(*shape)
+ return one_hot.to(torch.long)
+
+
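+# Tiny example for ``tensor_to_one_hot`` above (illustrative only): along ``dim=1`` the
+# maximum entry of each row becomes the one-hot element.
+def _tensor_to_one_hot_sketch():
+    scores = torch.tensor([[0.1, 0.7], [0.9, 0.2]])
+    return tensor_to_one_hot(scores, dim=1)  # tensor([[0, 1], [1, 0]])
+
+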
+class TensorStatistic(torchmetrics.Metric, RoutableMixin):
+ """Metric that computes summary statistic of tensors for logging purposes.
+
+ First dimension of tensor is assumed to be batch dimension. Other dimensions are reduced to a
+ scalar by the chosen reduction approach (sum or mean).
+ """
+
+ def __init__(self, path: Optional[str], reduction: str = "mean"):
+ torchmetrics.Metric.__init__(self)
+ RoutableMixin.__init__(self, {"tensor": path})
+ if reduction not in ("sum", "mean"):
+ raise ValueError(f"Unknown reduction {reduction}")
+ self.reduction = reduction
+ self.add_state(
+ "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
+ )
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ @RoutableMixin.route
+ def update(self, tensor: torch.Tensor):
+ tensor = torch.atleast_2d(tensor).flatten(1, -1).to(dtype=torch.float64)
+
+ if self.reduction == "mean":
+ tensor = torch.mean(tensor, dim=1)
+ elif self.reduction == "sum":
+ tensor = torch.sum(tensor, dim=1)
+
+ self.values += tensor.sum()
+ self.total += len(tensor)
+
+ def compute(self) -> torch.Tensor:
+ return self.values / self.total
+
+
+class TorchmetricsWrapper(torchmetrics.Metric, RoutableMixin):
+ """Wrapper for torchmetrics classes that works with routing."""
+
+ def __init__(
+ self,
+ metric: str,
+ prediction_path: str,
+ target_path: str,
+ metric_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ torchmetrics.Metric.__init__(self)
+ RoutableMixin.__init__(self, {"prediction": prediction_path, "target": target_path})
+ if not hasattr(torchmetrics, metric):
+ raise ValueError(f"Metric {metric} does not exist in torchmetrics")
+ self.metric = getattr(torchmetrics, metric)(**(metric_kwargs if metric_kwargs else {}))
+
+ @RoutableMixin.route
+ def update(self, prediction: torch.Tensor, target: torch.Tensor):
+ return self.metric.update(prediction, target)
+
+ def compute(self) -> torch.Tensor:
+ return self.metric.compute()
+
+
+class ARIMetric(torchmetrics.Metric, RoutableMixin):
+ """Computes ARI metric."""
+
+ def __init__(
+ self,
+ prediction_path: str,
+ target_path: str,
+ ignore_path: Optional[str] = None,
+ foreground: bool = True,
+ convert_target_one_hot: bool = False,
+ ignore_overlaps: bool = False,
+ ):
+ torchmetrics.Metric.__init__(self)
+ RoutableMixin.__init__(
+ self, {"prediction": prediction_path, "target": target_path, "ignore": ignore_path}
+ )
+ self.foreground = foreground
+ self.convert_target_one_hot = convert_target_one_hot
+ self.ignore_overlaps = ignore_overlaps
+ self.add_state(
+ "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
+ )
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ @RoutableMixin.route
+ def update(
+ self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
+ ):
+ """Update this metric.
+
+ Args:
+ prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
+ number of classes.
+ target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
+ number of classes.
+ ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
+ """
+ if prediction.ndim == 5:
+ # Merge frames, height and width to single dimension.
+ prediction = prediction.transpose(1, 2).flatten(-3, -1)
+ target = target.transpose(1, 2).flatten(-3, -1)
+ if ignore is not None:
+ ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
+ elif prediction.ndim == 4:
+ # Merge height and width to single dimension.
+ prediction = prediction.flatten(-2, -1)
+ target = target.flatten(-2, -1)
+ if ignore is not None:
+ ignore = ignore.to(torch.bool).flatten(-2, -1)
+ else:
+            raise ValueError(f"Incorrect input shape: {prediction.shape}")
+
+ if self.ignore_overlaps:
+ overlaps = (target > 0).sum(1, keepdim=True) > 1
+ if ignore is None:
+ ignore = overlaps
+ else:
+ ignore = ignore | overlaps
+
+ if ignore is not None:
+ assert ignore.ndim == 3 and ignore.shape[1] == 1
+ prediction = prediction.clone()
+ prediction[ignore.expand_as(prediction)] = 0
+ target = target.clone()
+ target[ignore.expand_as(target)] = 0
+
+ # Make channels / gt labels the last dimension.
+ prediction = prediction.transpose(-2, -1)
+ target = target.transpose(-2, -1)
+
+ if self.convert_target_one_hot:
+ target_oh = tensor_to_one_hot(target, dim=2)
+ # For empty pixels (all values zero), one-hot assigns 1 to the first class, correct for
+ # this (then it is technically not one-hot anymore).
+ target_oh[:, :, 0][target.sum(dim=2) == 0] = 0
+ target = target_oh
+
+ # Should be either 0 (empty, padding) or 1 (single object).
+ assert torch.all(target.sum(dim=-1) < 2), "Issues with target format, mask non-exclusive"
+
+ if self.foreground:
+ ari = fg_adjusted_rand_index(prediction, target)
+ else:
+ ari = adjusted_rand_index(prediction, target)
+
+ self.values += ari.sum()
+ self.total += len(ari)
+
+ def compute(self) -> torch.Tensor:
+ # print("ARI.total",self.total)
+ return self.values / self.total
+
+
+class MOTMetric(torchmetrics.Metric, RoutableMixin):
+ def __init__(
+ self,
+ prediction_path: str,
+ target_path: str,
+ target_is_mask: bool = True,
+ use_threshold: bool = True,
+ threshold: float = 0.5,
+ ):
+ torchmetrics.Metric.__init__(self)
+ RoutableMixin.__init__(self, {"prediction": prediction_path, "target": target_path})
+ self.target_is_mask = target_is_mask
+ self.use_threshold = use_threshold
+ self.threshold = threshold
+ self.reset_accumulator()
+ self.accuracy = []
+
+ self.add_state(
+ "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
+ )
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ def reset_accumulator(self):
+ self.acc = mm.MOTAccumulator(auto_id=True)
+
+ @RoutableMixin.route
+ def update(self, prediction: torch.Tensor, target: torch.Tensor):
+ # Merge batch and frame dimensions
+ B, F = prediction.shape[:2]
+ prediction = prediction.flatten(0, 1)
+ target = target.flatten(0, 1)
+
+ bs, n_pred_classes = prediction.shape[:2]
+ n_gt_classes = target.shape[1]
+
+ if self.use_threshold:
+ prediction = prediction > self.threshold
+ else:
+ indices = torch.argmax(prediction, dim=1)
+ prediction = torch.nn.functional.one_hot(indices, num_classes=n_pred_classes)
+ prediction = prediction.permute(0, 3, 1, 2)
+
+ pred_bboxes = masks_to_bboxes(prediction.flatten(0, 1)).unflatten(0, (B, F, n_pred_classes))
+ if self.target_is_mask:
+ target_bboxes = masks_to_bboxes(target.flatten(0, 1)).unflatten(0, (B, F, n_gt_classes))
+ else:
+ assert target.shape[-1] == 4
+ # Convert all-zero boxes added during padding to invalid boxes
+ target[torch.all(target == 0.0, dim=-1)] = -1.0
+ target_bboxes = target
+
+ self.reset_accumulator()
+ for preds, targets in zip(pred_bboxes, target_bboxes):
+ # seq evaluation
+ self.reset_accumulator()
+ for pred, target, mask in zip(preds, targets, prediction):
+ valid_track_box = pred[:, 0] != -1.0
+ valid_target_box = target[:, 0] != -1.0
+
+ track_id = valid_track_box.nonzero()[:, 0].detach().cpu().numpy()
+ target_id = valid_target_box.nonzero()[:, 0].detach().cpu().numpy()
+
+                # Drop near-background predictions, i.e. boxes covering at least 25% of the frame.
+                # Iterate over a copy so that removing entries does not skip elements.
+                idx = track_id.tolist()
+                for id in list(idx):
+                    h, w = mask[id].shape
+                    thres = h * w * 0.25
+                    if pred[id][2] * pred[id][3] >= thres:
+                        idx.remove(id)
+                cur_obj_idx = np.array(idx)
+
+ if valid_target_box.sum() == 0:
+ continue # Skip data points without any target bbox
+
+ pred = pred[cur_obj_idx].detach().cpu().numpy()
+ target = target[valid_target_box].detach().cpu().numpy()
+ # frame evaluation
+ self.eval_frame(pred, target, cur_obj_idx, target_id)
+ self.accuracy.append(self.acc)
+
+ self.total += 1
+
+ def eval_frame(self, trk_tlwhs, tgt_tlwhs, trk_ids, tgt_ids):
+ # get distance matrix
+ trk_tlwhs = np.copy(trk_tlwhs)
+ tgt_tlwhs = np.copy(tgt_tlwhs)
+ trk_ids = np.copy(trk_ids)
+ tgt_ids = np.copy(tgt_ids)
+ iou_distance = mm.distances.iou_matrix(tgt_tlwhs, trk_tlwhs, max_iou=0.5)
+ # acc
+ self.acc.update(tgt_ids, trk_ids, iou_distance)
+
+    def convert_motmetric_to_value(self, res):
+        # Convert the fixed-width summary string rendered by motmetrics into a
+        # semicolon-separated buffer that pandas can parse as CSV.
+        dp = res.replace(" ", ";").replace(";;", ";").replace(";;", ";").replace(";;", ";")
+ tmp = list(dp)
+ tmp[0] = "-"
+ dp = "".join(tmp)
+ return io.StringIO(dp)
+
+ def compute(self) -> torch.Tensor:
+ if self.total == 0:
+ return torch.zeros_like(self.values)
+ else:
+ metrics = mm.metrics.motchallenge_metrics
+ mh = mm.metrics.create()
+ summary = mh.compute_many(
+ self.accuracy, metrics=metrics, names=None, generate_overall=True
+ )
+ strsummary = mm.io.render_summary(
+ summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names
+ )
+ res = self.convert_motmetric_to_value(strsummary)
+ df = pd.read_csv(res, sep=";", engine="python")
+
+ mota = df.iloc[-1]["MOTA"]
+ self.values = torch.tensor(float(mota[:-1]), dtype=torch.float64).to(self.values.device)
+ self.reset_accumulator()
+ self.accuracy = []
+ return self.values
+
+
+class PatchARIMetric(ARIMetric):
+ """Computes ARI metric assuming patch masks as input."""
+
+ def __init__(
+ self,
+ prediction_key: str,
+ target_key: str,
+ foreground=True,
+ resize_masks_mode: str = "bilinear",
+ **kwargs,
+ ):
+ super().__init__(prediction_key, target_key, foreground, **kwargs)
+ self.resize_masks_mode = resize_masks_mode
+
+ @RoutableMixin.route
+ def update(self, prediction: torch.Tensor, target: torch.Tensor):
+ """Update this metric.
+
+ Args:
+ prediction: Predicted mask of shape (B, C, P) or (B, F, C, P), where C is the
+ number of classes and P the number of patches.
+ target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
+ number of classes.
+ """
+ h, w = target.shape[-2:]
+ assert h == w
+
+ prediction_resized = resize_patches_to_image(
+ prediction, size=h, resize_mode=self.resize_masks_mode
+ )
+
+ return super().update(prediction=prediction_resized, target=target)
+
+
+def adjusted_rand_index(pred_mask: torch.Tensor, true_mask: torch.Tensor) -> torch.Tensor:
+ """Computes adjusted Rand index (ARI), a clustering similarity score.
+
+ This implementation ignores points with no cluster label in `true_mask` (i.e. those points for
+ which `true_mask` is a zero vector). In the context of segmentation, that means this function
+ can ignore points in an image corresponding to the background (i.e. not to an object).
+
+ Implementation adapted from https://github.com/deepmind/multi_object_datasets and
+ https://github.com/google-research/slot-attention-video/blob/main/savi/lib/metrics.py
+
+ Args:
+ pred_mask: Predicted cluster assignment encoded as categorical probabilities of shape
+ (batch_size, n_points, n_pred_clusters).
+ true_mask: True cluster assignment encoded as one-hot of shape (batch_size, n_points,
+ n_true_clusters).
+
+ Returns:
+ ARI scores of shape (batch_size,).
+ """
+ n_pred_clusters = pred_mask.shape[-1]
+ pred_cluster_ids = torch.argmax(pred_mask, axis=-1)
+
+ # Convert true and predicted clusters to one-hot ('oh') representations. We use float64 here on
+ # purpose, otherwise mixed precision training automatically casts to FP16 in some of the
+ # operations below, which can create overflows.
+ true_mask_oh = true_mask.to(torch.float64) # already one-hot
+ pred_mask_oh = torch.nn.functional.one_hot(pred_cluster_ids, n_pred_clusters).to(torch.float64)
+
+ n_ij = torch.einsum("bnc,bnk->bck", true_mask_oh, pred_mask_oh)
+ a = torch.sum(n_ij, axis=-1)
+ b = torch.sum(n_ij, axis=-2)
+ n_fg_points = torch.sum(a, axis=1)
+
+ rindex = torch.sum(n_ij * (n_ij - 1), axis=(1, 2))
+ aindex = torch.sum(a * (a - 1), axis=1)
+ bindex = torch.sum(b * (b - 1), axis=1)
+ expected_rindex = aindex * bindex / torch.clamp(n_fg_points * (n_fg_points - 1), min=1)
+ max_rindex = (aindex + bindex) / 2
+ denominator = max_rindex - expected_rindex
+ ari = (rindex - expected_rindex) / denominator
+
+ # There are two cases for which the denominator can be zero:
+ # 1. If both true_mask and pred_mask assign all pixels to a single cluster.
+ # (max_rindex == expected_rindex == rindex == n_fg_points * (n_fg_points-1))
+ # 2. If both true_mask and pred_mask assign max 1 point to each cluster.
+ # (max_rindex == expected_rindex == rindex == 0)
+ # In both cases, we want the ARI score to be 1.0:
+ return torch.where(denominator > 0, ari, torch.ones_like(ari))
+
+
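+# Minimal usage sketch for ``adjusted_rand_index`` above (illustrative only; random toy
+# masks): predictions are soft cluster probabilities per point, targets are one-hot
+# cluster assignments; the score is 1.0 for a perfect clustering (up to permutation)
+# and close to 0.0 for a random one.
+def _ari_usage_sketch():
+    true_ids = torch.randint(0, 3, (2, 16))  # (batch_size, n_points)
+    true_mask = torch.nn.functional.one_hot(true_ids, 3)  # (batch, points, n_true_clusters)
+    pred_mask = torch.rand(2, 16, 4)  # (batch, points, n_pred_clusters), soft scores
+    return adjusted_rand_index(pred_mask, true_mask)  # shape (2,)
+
+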
+def fg_adjusted_rand_index(
+ pred_mask: torch.Tensor, true_mask: torch.Tensor, bg_dim: int = 0
+) -> torch.Tensor:
+ """Compute adjusted random index using only foreground groups (FG-ARI).
+
+ Args:
+ pred_mask: Predicted cluster assignment encoded as categorical probabilities of shape
+ (batch_size, n_points, n_pred_clusters).
+ true_mask: True cluster assignment encoded as one-hot of shape (batch_size, n_points,
+ n_true_clusters).
+ bg_dim: Index of background class in true mask.
+
+ Returns:
+ ARI scores of shape (batch_size,).
+ """
+ n_true_clusters = true_mask.shape[-1]
+ assert 0 <= bg_dim < n_true_clusters
+ if bg_dim == 0:
+ true_mask_only_fg = true_mask[..., 1:]
+ elif bg_dim == n_true_clusters - 1:
+ true_mask_only_fg = true_mask[..., :-1]
+ else:
+ true_mask_only_fg = torch.cat(
+ (true_mask[..., :bg_dim], true_mask[..., bg_dim + 1 :]), dim=-1
+ )
+
+ return adjusted_rand_index(pred_mask, true_mask_only_fg)
+
+
+def _all_equal_masked(values: torch.Tensor, mask: torch.Tensor, dim=-1) -> torch.Tensor:
+ """Check if all masked values along a dimension of a tensor are the same.
+
+ All non-masked values are considered as true, i.e. if no value is masked, true is returned
+ for this dimension.
+ """
+ assert mask.dtype == torch.bool
+ _, first_non_masked_idx = torch.max(mask, dim=dim)
+
+ comparison_value = values.gather(index=first_non_masked_idx.unsqueeze(dim), dim=dim)
+
+ return torch.logical_or(~mask, values == comparison_value).all(dim=dim)
+
+
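+# Tiny example for ``_all_equal_masked`` above (illustrative only): only the masked-in
+# values have to agree, so [1, 2, 2] with mask [False, True, True] gives True while
+# [1, 2, 3] gives False.
+def _all_equal_masked_sketch():
+    values = torch.tensor([[1, 2, 2], [1, 2, 3]])
+    mask = torch.tensor([[False, True, True], [False, True, True]])
+    return _all_equal_masked(values, mask)  # tensor([True, False])
+
+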
+class UnsupervisedMaskIoUMetric(torchmetrics.Metric, RoutableMixin):
+ """Computes IoU metric for segmentation masks when correspondences to ground truth are not known.
+
+ Uses Hungarian matching to compute the assignment between predicted classes and ground truth
+ classes.
+
+ Args:
+ use_threshold: If `True`, convert predicted class probabilities to mask using a threshold.
+            If `False`, class probabilities are turned into a hard one-hot mask via an argmax instead.
+ threshold: Value to use for thresholding masks.
+ matching: Approach to match predicted to ground truth classes. For "hungarian", computes
+ assignment that maximizes total IoU between all classes. For "best_overlap", uses the
+ predicted class with maximum overlap for each ground truth class. Using "best_overlap"
+ leads to the "average best overlap" metric.
+ compute_discovery_fraction: Instead of the IoU, compute the fraction of ground truth classes
+ that were "discovered", meaning that they have an IoU greater than some threshold.
+ correct_localization: Instead of the IoU, compute the fraction of images on which at least
+ one ground truth class was correctly localised, meaning that they have an IoU
+ greater than some threshold.
+ discovery_threshold: Minimum IoU to count a class as discovered/correctly localized.
+ ignore_background: If true, assume class at index 0 of ground truth masks is background class
+ that is removed before computing IoU.
+        ignore_overlaps: If true, remove points where the ground truth masks have overlapping
+            classes from both the predictions and the ground truth masks.
+ """
+
+ def __init__(
+ self,
+ prediction_path: str,
+ target_path: str,
+ ignore_path: Optional[str] = None,
+ use_threshold: bool = False,
+ threshold: float = 0.5,
+ matching: str = "hungarian",
+ compute_discovery_fraction: bool = False,
+ correct_localization: bool = False,
+ discovery_threshold: float = 0.5,
+ ignore_background: bool = False,
+ ignore_overlaps: bool = False,
+ ):
+ torchmetrics.Metric.__init__(self)
+ RoutableMixin.__init__(
+ self, {"prediction": prediction_path, "target": target_path, "ignore": ignore_path}
+ )
+ self.use_threshold = use_threshold
+ self.threshold = threshold
+ self.discovery_threshold = discovery_threshold
+ self.compute_discovery_fraction = compute_discovery_fraction
+ self.correct_localization = correct_localization
+ if compute_discovery_fraction and correct_localization:
+ raise ValueError(
+ "Only one of `compute_discovery_fraction` and `correct_localization` can be enabled."
+ )
+
+ matchings = ("hungarian", "best_overlap")
+ if matching not in matchings:
+ raise ValueError(f"Unknown matching type {matching}. Valid values are {matchings}.")
+ self.matching = matching
+ self.ignore_background = ignore_background
+ self.ignore_overlaps = ignore_overlaps
+
+ self.add_state(
+ "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
+ )
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ @RoutableMixin.route
+ def update(
+ self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
+ ):
+ """Update this metric.
+
+ Args:
+ prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
+ number of classes. Assumes class probabilities as inputs.
+ target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
+ number of classes.
+ ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
+ """
+ if prediction.ndim == 5:
+ # Merge frames, height and width to single dimension.
+ predictions = prediction.transpose(1, 2).flatten(-3, -1)
+ targets = target.transpose(1, 2).flatten(-3, -1)
+ if ignore is not None:
+ ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
+ elif prediction.ndim == 4:
+ # Merge height and width to single dimension.
+ predictions = prediction.flatten(-2, -1)
+ targets = target.flatten(-2, -1)
+ if ignore is not None:
+ ignore = ignore.to(torch.bool).flatten(-2, -1)
+ else:
+            raise ValueError(f"Incorrect input shape: {prediction.shape}")
+
+ if self.use_threshold:
+ predictions = predictions > self.threshold
+ else:
+ indices = torch.argmax(predictions, dim=1)
+ predictions = torch.nn.functional.one_hot(indices, num_classes=predictions.shape[1])
+ predictions = predictions.transpose(1, 2)
+
+ if self.ignore_background:
+ targets = targets[:, 1:]
+
+ targets = targets > 0 # Ensure masks are binary
+
+ if self.ignore_overlaps:
+ overlaps = targets.sum(1, keepdim=True) > 1
+ if ignore is None:
+ ignore = overlaps
+ else:
+ ignore = ignore | overlaps
+
+ if ignore is not None:
+ assert ignore.ndim == 3 and ignore.shape[1] == 1
+ predictions[ignore.expand_as(predictions)] = 0
+ targets[ignore.expand_as(targets)] = 0
+
+ # Should be either 0 (empty, padding) or 1 (single object).
+ assert torch.all(targets.sum(dim=1) < 2), "Issues with target format, mask non-exclusive"
+
+ for pred, target in zip(predictions, targets):
+ nonzero_classes = torch.sum(target, dim=-1) > 0
+ target = target[nonzero_classes] # Remove empty (e.g. padded) classes
+ if len(target) == 0:
+ continue # Skip elements without any target mask
+
+ iou_per_class = unsupervised_mask_iou(
+ pred, target, matching=self.matching, reduction="none"
+ )
+
+ if self.compute_discovery_fraction:
+ discovered = iou_per_class > self.discovery_threshold
+ self.values += discovered.sum() / len(discovered)
+ elif self.correct_localization:
+ correctly_localized = torch.any(iou_per_class > self.discovery_threshold)
+ self.values += correctly_localized.sum()
+ else:
+ self.values += iou_per_class.mean()
+ self.total += 1
+
+ def compute(self) -> torch.Tensor:
+ # print("mIoU.total",self.total)
+ if self.total == 0:
+ return torch.zeros_like(self.values)
+ else:
+ return self.values / self.total
+
+
+def unsupervised_mask_iou(
+ pred_mask: torch.Tensor,
+ true_mask: torch.Tensor,
+ matching: str = "hungarian",
+ reduction: str = "mean",
+ iou_empty: float = 0.0,
+) -> torch.Tensor:
+ """Compute intersection-over-union (IoU) between masks with unknown class correspondences.
+
+ This metric is also known as Jaccard index. Note that this is a non-batched implementation.
+
+ Args:
+ pred_mask: Predicted mask of shape (C, N), where C is the number of predicted classes and
+ N is the number of points. Masks are assumed to be binary.
+ true_mask: Ground truth mask of shape (K, N), where K is the number of ground truth
+ classes and N is the number of points. Masks are assumed to be binary.
+ matching: How to match predicted classes to ground truth classes. For "hungarian", computes
+ assignment that maximizes total IoU between all classes. For "best_overlap", uses the
+ predicted class with maximum overlap for each ground truth class (each predicted class
+ can be assigned to multiple ground truth classes). Empty ground truth classes are
+ assigned IoU of zero.
+ reduction: If "mean", return IoU averaged over classes. If "none", return per-class IoU.
+ iou_empty: IoU for the case when a class does not occur, but was also not predicted.
+
+ Returns:
+ Mean IoU over classes if reduction is `mean`, tensor of shape (K,) containing per-class IoU
+ otherwise.
+ """
+ assert pred_mask.ndim == 2
+ assert true_mask.ndim == 2
+ n_gt_classes = len(true_mask)
+ pred_mask = pred_mask.unsqueeze(1).to(torch.bool)
+ true_mask = true_mask.unsqueeze(0).to(torch.bool)
+
+ intersection = torch.sum(pred_mask & true_mask, dim=-1).to(torch.float64)
+ union = torch.sum(pred_mask | true_mask, dim=-1).to(torch.float64)
+ pairwise_iou = intersection / union
+
+ # Remove NaN from divide-by-zero: class does not occur, and class was not predicted.
+ pairwise_iou[union == 0] = iou_empty
+
+ if matching == "hungarian":
+ pred_idxs, true_idxs = scipy.optimize.linear_sum_assignment(
+ pairwise_iou.cpu(), maximize=True
+ )
+ pred_idxs = torch.as_tensor(pred_idxs, dtype=torch.int64, device=pairwise_iou.device)
+ true_idxs = torch.as_tensor(true_idxs, dtype=torch.int64, device=pairwise_iou.device)
+ elif matching == "best_overlap":
+ non_empty_gt = torch.sum(true_mask.squeeze(0), dim=1) > 0
+ pred_idxs = torch.argmax(pairwise_iou, dim=0)[non_empty_gt]
+ true_idxs = torch.arange(pairwise_iou.shape[1])[non_empty_gt]
+ else:
+ raise ValueError(f"Unknown matching {matching}")
+
+ matched_iou = pairwise_iou[pred_idxs, true_idxs]
+ iou = torch.zeros(n_gt_classes, dtype=torch.float64, device=pairwise_iou.device)
+ iou[true_idxs] = matched_iou
+
+ if reduction == "mean":
+ return iou.mean()
+ else:
+ return iou
+
+
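+# Minimal usage sketch for ``unsupervised_mask_iou`` above (illustrative only; random
+# binary toy masks flattened to points): with Hungarian matching, predicted and ground
+# truth classes are paired so that the total IoU is maximized, and the per-class IoU of
+# the matched pairs is returned.
+def _unsupervised_mask_iou_sketch():
+    pred_mask = torch.rand(5, 64) > 0.5  # (n_pred_classes, n_points), binary
+    true_mask = torch.rand(3, 64) > 0.5  # (n_gt_classes, n_points), binary
+    return unsupervised_mask_iou(pred_mask, true_mask, matching="hungarian", reduction="none")
+
+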
+class UnsupervisedBboxIoUMetric(torchmetrics.Metric, RoutableMixin):
+ """Computes IoU metric for bounding boxes when correspondences to ground truth are not known.
+
+ Currently, assumes segmentation masks as input for both prediction and targets.
+
+ Args:
+ target_is_mask: If `True`, assume input is a segmentation mask, in which case the masks are
+ converted to bounding boxes before computing IoU. If `False`, assume the input for the
+ targets are already bounding boxes.
+ use_threshold: If `True`, convert predicted class probabilities to mask using a threshold.
+            If `False`, class probabilities are turned into a hard one-hot mask via an argmax instead.
+ threshold: Value to use for thresholding masks.
+ matching: How to match predicted boxes to ground truth boxes. For "hungarian", computes
+ assignment that maximizes total IoU between all boxes. For "best_overlap", uses the
+ predicted box with maximum overlap for each ground truth box (each predicted box
+ can be assigned to multiple ground truth boxes).
+ compute_discovery_fraction: Instead of the IoU, compute the fraction of ground truth classes
+ that were "discovered", meaning that they have an IoU greater than some threshold. This
+ is recall, or sometimes called the detection rate metric.
+ correct_localization: Instead of the IoU, compute the fraction of images on which at least
+ one ground truth bounding box was correctly localised, meaning that they have an IoU
+ greater than some threshold.
+ discovery_threshold: Minimum IoU to count a class as discovered/correctly localized.
+ """
+
+ def __init__(
+ self,
+ prediction_path: str,
+ target_path: str,
+ target_is_mask: bool = False,
+ use_threshold: bool = False,
+ threshold: float = 0.5,
+ matching: str = "hungarian",
+ compute_discovery_fraction: bool = False,
+ correct_localization: bool = False,
+ discovery_threshold: float = 0.5,
+ ):
+ torchmetrics.Metric.__init__(self)
+ RoutableMixin.__init__(self, {"prediction": prediction_path, "target": target_path})
+ self.target_is_mask = target_is_mask
+ self.use_threshold = use_threshold
+ self.threshold = threshold
+ self.discovery_threshold = discovery_threshold
+ self.compute_discovery_fraction = compute_discovery_fraction
+ self.correct_localization = correct_localization
+ if compute_discovery_fraction and correct_localization:
+ raise ValueError(
+ "Only one of `compute_discovery_fraction` and `correct_localization` can be enabled."
+ )
+
+ matchings = ("hungarian", "best_overlap")
+ if matching not in matchings:
+ raise ValueError(f"Unknown matching type {matching}. Valid values are {matchings}.")
+ self.matching = matching
+
+ self.add_state(
+ "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
+ )
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ @RoutableMixin.route
+ def update(self, prediction: torch.Tensor, target: torch.Tensor):
+ """Update this metric.
+
+ Args:
+ prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
+ number of instances. Assumes class probabilities as inputs.
+ target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
+ number of instances, if using masks as input, or bounding boxes of shape (B, K, 4)
+ or (B, F, K, 4).
+ """
+ if prediction.ndim == 5:
+ # Merge batch and frame dimensions
+ prediction = prediction.flatten(0, 1)
+ target = target.flatten(0, 1)
+ elif prediction.ndim != 4:
+ raise ValueError(f"Incorrect input shape: f{prediction.shape}")
+
+ bs, n_pred_classes = prediction.shape[:2]
+ n_gt_classes = target.shape[1]
+
+ if self.use_threshold:
+ prediction = prediction > self.threshold
+ else:
+ indices = torch.argmax(prediction, dim=1)
+ prediction = torch.nn.functional.one_hot(indices, num_classes=n_pred_classes)
+ prediction = prediction.permute(0, 3, 1, 2)
+
+ pred_bboxes = masks_to_bboxes(prediction.flatten(0, 1)).unflatten(0, (bs, n_pred_classes))
+ if self.target_is_mask:
+ target_bboxes = masks_to_bboxes(target.flatten(0, 1)).unflatten(0, (bs, n_gt_classes))
+ else:
+ assert target.shape[-1] == 4
+ # Convert all-zero boxes added during padding to invalid boxes
+ target[torch.all(target == 0.0, dim=-1)] = -1.0
+ target_bboxes = target
+
+ for pred, target in zip(pred_bboxes, target_bboxes):
+ valid_pred_bboxes = pred[:, 0] != -1.0
+ valid_target_bboxes = target[:, 0] != -1.0
+ if valid_target_bboxes.sum() == 0:
+ continue # Skip data points without any target bbox
+
+ pred = pred[valid_pred_bboxes]
+ target = target[valid_target_bboxes]
+
+ if valid_pred_bboxes.sum() > 0:
+ iou_per_bbox = unsupervised_bbox_iou(
+ pred, target, matching=self.matching, reduction="none"
+ )
+ else:
+ iou_per_bbox = torch.zeros_like(valid_target_bboxes, dtype=torch.float32)
+
+ if self.compute_discovery_fraction:
+ discovered = iou_per_bbox > self.discovery_threshold
+ self.values += discovered.sum() / len(iou_per_bbox)
+ elif self.correct_localization:
+ correctly_localized = torch.any(iou_per_bbox > self.discovery_threshold)
+ self.values += correctly_localized.sum()
+ else:
+ self.values += iou_per_bbox.mean()
+ self.total += 1
+
+ def compute(self) -> torch.Tensor:
+ if self.total == 0:
+ return torch.zeros_like(self.values)
+ else:
+ return self.values / self.total
+
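+# Minimal construction sketch (the path strings are placeholders for wherever masks live in the
+# output tree, not values prescribed by this class):
+#   metric = UnsupervisedBboxIoUMetric(
+#       prediction_path="object_decoder.masks",
+#       target_path="input.mask",
+#       target_is_mask=True,
+#       matching="best_overlap",
+#   )
+# RoutableMixin then extracts `prediction` and `target` from those locations when `update` is
+# called during evaluation, and `compute` returns the IoU averaged over all processed images.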
+
+def unsupervised_bbox_iou(
+ pred_bboxes: torch.Tensor,
+ true_bboxes: torch.Tensor,
+ matching: str = "best_overlap",
+ reduction: str = "mean",
+) -> torch.Tensor:
+ """Compute IoU between two sets of bounding boxes.
+
+ Args:
+ pred_bboxes: Predicted bounding boxes of shape N x 4.
+ true_bboxes: True bounding boxes of shape M x 4.
+ matching: Method to assign predicted to true bounding boxes.
+ reduction: If "mean", average the computed IoUs over true boxes; if "none", return per-box IoUs.
+ """
+ n_gt_bboxes = len(true_bboxes)
+
+ pairwise_iou = torchvision.ops.box_iou(pred_bboxes, true_bboxes)
+
+ if matching == "hungarian":
+ pred_idxs, true_idxs = scipy.optimize.linear_sum_assignment(
+ pairwise_iou.cpu(), maximize=True
+ )
+ pred_idxs = torch.as_tensor(pred_idxs, dtype=torch.int64, device=pairwise_iou.device)
+ true_idxs = torch.as_tensor(true_idxs, dtype=torch.int64, device=pairwise_iou.device)
+ elif matching == "best_overlap":
+ pred_idxs = torch.argmax(pairwise_iou, dim=0)
+ true_idxs = torch.arange(pairwise_iou.shape[1], device=pairwise_iou.device)
+ else:
+ raise ValueError(f"Unknown matching {matching}")
+
+ matched_iou = pairwise_iou[pred_idxs, true_idxs]
+
+ iou = torch.zeros(n_gt_bboxes, dtype=torch.float32, device=pairwise_iou.device)
+ iou[true_idxs] = matched_iou
+
+ if reduction == "mean":
+ return iou.mean()
+ else:
+ return iou
+
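+# Example with hypothetical boxes: one prediction, two ground truth boxes,
+#   pred = torch.tensor([[0., 0., 10., 10.]])
+#   true = torch.tensor([[0., 0., 10., 10.], [20., 20., 30., 30.]])
+#   unsupervised_bbox_iou(pred, true, matching="best_overlap", reduction="none")
+# returns tensor([1., 0.]): the prediction matches the first box perfectly, while the second box
+# has no overlapping prediction and therefore gets IoU zero.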
+
+def masks_to_bboxes(masks: torch.Tensor, empty_value: float = -1.0) -> torch.Tensor:
+ """Compute bounding boxes around the provided masks.
+
+ Adapted from DETR: https://github.com/facebookresearch/detr/blob/main/util/box_ops.py
+
+ Args:
+ masks: Tensor of shape (N, H, W), where N is the number of masks, H and W are the spatial
+ dimensions.
+ empty_value: Value bounding boxes should contain for empty masks.
+
+ Returns:
+ Tensor of shape (N, 4), containing bounding boxes in (x1, y1, x2, y2) format, where (x1, y1)
+ is the coordinate of top-left corner and (x2, y2) is the coordinate of the bottom-right
+ corner (inclusive) in pixel coordinates. If mask is empty, all coordinates contain
+ `empty_value` instead.
+ """
+ masks = masks.bool()
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ large_value = 1e8
+ inv_mask = ~masks
+
+ h, w = masks.shape[-2:]
+
+ y = torch.arange(0, h, dtype=torch.float, device=masks.device)
+ x = torch.arange(0, w, dtype=torch.float, device=masks.device)
+ y, x = torch.meshgrid(y, x, indexing="ij")
+
+ x_mask = masks * x.unsqueeze(0)
+ x_max = x_mask.flatten(1).max(-1)[0]
+ x_min = x_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]
+
+ y_mask = masks * y.unsqueeze(0)
+ y_max = y_mask.flatten(1).max(-1)[0]
+ y_min = y_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]
+
+ bboxes = torch.stack((x_min, y_min, x_max, y_max), dim=1)
+ bboxes[x_min == large_value] = empty_value
+
+ return bboxes
+
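+# Example: a mask covering rows 1-2 and columns 2-3 of a 4x4 grid yields the inclusive box
+# (x1, y1, x2, y2) = (2, 1, 3, 2); an all-zero mask yields (-1, -1, -1, -1) by default.
+#   m = torch.zeros(1, 4, 4)
+#   m[0, 1:3, 2:4] = 1
+#   masks_to_bboxes(m)  # -> tensor([[2., 1., 3., 2.]])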
+
+class DatasetSemanticMaskIoUMetric(torchmetrics.Metric):
+ """Unsupervised IoU metric for semantic segmentation using dataset-wide matching of classes.
+
+ The input to this metric is an instance-level mask with objects, and a class id for each object.
+ This is required to convert the mask to semantic classes. The number of classes for the
+ predictions does not have to match the true number of classes.
+
+ Note that contrary to the other metrics in this module, this metric is not supposed to be added
+ in the online metric computation loop, which is why it does not inherit from `RoutableMixin`.
+
+ Args:
+ n_predicted_classes: Number of predictable classes, i.e. highest prediction class id that can
+ occur.
+ n_classes: Total number of classes, i.e. highest class id that can occur.
+ threshold: Value to use for thresholding masks.
+ use_threshold: If `True`, convert predicted class probabilities to mask using a threshold.
+ If `False`, class probabilities are turned into a mask using an argmax instead.
+ matching: Method to produce matching between clusters and ground truth classes. If
+ "hungarian", assigns each class one cluster such that the total IoU is maximized. If
+ "majority", assigns each cluster to the class with the highest IoU (each class can be
+ assigned multiple clusters).
+ ignore_background: If true, pixels labeled as background (class zero) in the ground truth
+ are not taken into account when computing IoU.
+ use_unmatched_as_background: If true, count predicted classes not selected after Hungarian
+ matching as the background predictions.
+ """
+
+ def __init__(
+ self,
+ n_predicted_classes: int,
+ n_classes: int,
+ use_threshold: bool = False,
+ threshold: float = 0.5,
+ matching: str = "hungarian",
+ ignore_background: bool = False,
+ use_unmatched_as_background: bool = False,
+ ):
+ super().__init__()
+ matching_methods = {"hungarian", "majority"}
+ if matching not in matching_methods:
+ raise ValueError(
+ f"Unknown matching method {matching}. Valid values are {matching_methods}."
+ )
+
+ self.matching = matching
+ self.n_predicted_classes = n_predicted_classes
+ self.n_predicted_classes_with_bg = n_predicted_classes + 1
+ self.n_classes = n_classes
+ self.n_classes_with_bg = n_classes + 1
+ self.matching = matching
+ self.use_threshold = use_threshold
+ self.threshold = threshold
+ self.ignore_background = ignore_background
+ self.use_unmatched_as_background = use_unmatched_as_background
+ if use_unmatched_as_background and ignore_background:
+ raise ValueError(
+ "Option `use_unmatched_as_background` not compatible with option `ignore_background`"
+ )
+ if use_unmatched_as_background and matching == "majority":
+ raise ValueError(
+ "Option `use_unmatched_as_background` not compatible with matching `majority`"
+ )
+
+ confusion_mat = torch.zeros(
+ self.n_predicted_classes_with_bg, self.n_classes_with_bg, dtype=torch.int64
+ )
+ self.add_state("confusion_mat", default=confusion_mat, dist_reduce_fx="sum", persistent=True)
+
+ def update(
+ self,
+ predictions: torch.Tensor,
+ targets: torch.Tensor,
+ prediction_class_ids: torch.Tensor,
+ ignore: Optional[torch.Tensor] = None,
+ ):
+ """Update metric by computing confusion matrix between predicted and target classes.
+
+ Args:
+ predictions: Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
+ number of object instances in the image.
+ targets: Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object
+ instances in the image. Class ID of objects is encoded as the value, i.e. densely
+ represented.
+ prediction_class_ids: Tensor of shape (B, K), containing the class id of each predicted
+ object instance in the image. Id must be 0 <= id <= n_predicted_classes.
+ ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
+ """
+ predictions = self.preprocess_predicted_mask(predictions)
+ predictions = _remap_one_hot_mask(
+ predictions, prediction_class_ids, self.n_predicted_classes, strip_empty=False
+ )
+ assert predictions.shape[-1] == self.n_predicted_classes_with_bg
+
+ targets = self.preprocess_ground_truth_mask(targets)
+ assert targets.shape[-1] == self.n_classes_with_bg
+
+ if ignore is not None:
+ if ignore.ndim == 5: # Video case
+ ignore = ignore.flatten(0, 1)
+ assert ignore.ndim == 4 and ignore.shape[1] == 1
+ ignore = ignore.to(torch.bool).flatten(-2, -1).squeeze(1) # B x P
+ predictions[ignore] = 0
+ targets[ignore] = 0
+
+ # We are doing the multiply in float64 instead of int64 because it proved to be significantly
+ # faster on GPU. We need to use 64 bits because we can easily exceed the range of 32 bits
+ # if we aggregate over a full dataset.
+ confusion_mat = torch.einsum(
+ "bpk,bpc->kc", predictions.to(torch.float64), targets.to(torch.float64)
+ )
+ self.confusion_mat += confusion_mat.to(torch.int64)
+
+ def preprocess_predicted_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """Preprocess predicted masks for metric computation.
+
+ Args:
+ mask: Probability mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number
+ of object instances in the prediction.
+
+ Returns:
+ Binary tensor of shape (B, P, K), where P is the number of points. If `use_threshold` is
+ True, overlapping objects for the same point are possible.
+ """
+ if mask.ndim == 5: # Video case
+ mask = mask.flatten(0, 1)
+ mask = mask.flatten(-2, -1)
+
+ if self.use_threshold:
+ mask = mask > self.threshold
+ mask = mask.transpose(1, 2)
+ else:
+ maximum, indices = torch.max(mask, dim=1)
+ mask = torch.nn.functional.one_hot(indices, num_classes=mask.shape[1])
+ mask[:, :, 0][maximum == 0.0] = 0
+
+ return mask
+
+ def preprocess_ground_truth_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """Preprocess ground truth mask for metric computation.
+
+ Args:
+ mask: Mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the number of object
+ instances in the image. Class ID of objects is encoded as the value, i.e. densely
+ represented.
+
+ Returns:
+ One-hot tensor of shape (B, P, J), where J is the number of classes and P the number
+ of points, with object instances with the same class ID merged together. In the case of
+ an overlap of classes for a point, the class with the highest ID is assigned to that
+ point.
+ """
+ if mask.ndim == 5: # Video case
+ mask = mask.flatten(0, 1)
+ mask = mask.flatten(-2, -1)
+
+ # Pixels which contain no object get assigned the background class 0. This also handles the
+ # padding of zero masks which is done in preprocessing for batching.
+ mask = torch.nn.functional.one_hot(
+ mask.max(dim=1).values.to(torch.long), num_classes=self.n_classes_with_bg
+ )
+
+ return mask
+
+ def compute(self):
+ """Compute per-class IoU using matching."""
+ if self.ignore_background:
+ n_classes = self.n_classes
+ confusion_mat = self.confusion_mat[:, 1:]
+ else:
+ n_classes = self.n_classes_with_bg
+ confusion_mat = self.confusion_mat
+
+ pairwise_iou, _, _, area_gt = self._compute_iou_from_confusion_mat(confusion_mat)
+
+ if self.use_unmatched_as_background:
+ # Match only in foreground
+ pairwise_iou = pairwise_iou[1:, 1:]
+ confusion_mat = confusion_mat[1:, 1:]
+ else:
+ # Predicted class zero is not matched against anything
+ pairwise_iou = pairwise_iou[1:]
+ confusion_mat = confusion_mat[1:]
+
+ if self.matching == "hungarian":
+ cluster_idxs, class_idxs = scipy.optimize.linear_sum_assignment(
+ pairwise_iou.cpu(), maximize=True
+ )
+ cluster_idxs = torch.as_tensor(
+ cluster_idxs, dtype=torch.int64, device=self.confusion_mat.device
+ )
+ class_idxs = torch.as_tensor(
+ class_idxs, dtype=torch.int64, device=self.confusion_mat.device
+ )
+ matched_iou = pairwise_iou[cluster_idxs, class_idxs]
+ true_pos = confusion_mat[cluster_idxs, class_idxs]
+
+ if self.use_unmatched_as_background:
+ cluster_oh = torch.nn.functional.one_hot(
+ cluster_idxs, num_classes=pairwise_iou.shape[0]
+ )
+ matched_clusters = cluster_oh.max(dim=0).values.to(torch.bool)
+ bg_pred = self.confusion_mat[:1]
+ bg_pred += self.confusion_mat[1:][~matched_clusters].sum(dim=0)
+ bg_iou, _, _, _ = self._compute_iou_from_confusion_mat(bg_pred, area_gt)
+ class_idxs = torch.cat((torch.zeros_like(class_idxs[:1]), class_idxs + 1))
+ matched_iou = torch.cat((bg_iou[0, :1], matched_iou))
+ true_pos = torch.cat((bg_pred[0, :1], true_pos))
+
+ elif self.matching == "majority":
+ max_iou, class_idxs = torch.max(pairwise_iou, dim=1)
+ # Form new clusters by merging old clusters which are assigned the same ground truth
+ # class. After merging, the number of clusters equals the number of classes.
+ _, old_to_new_cluster_idx = torch.unique(class_idxs, return_inverse=True)
+
+ confusion_mat_new = torch.zeros(
+ n_classes, n_classes, dtype=torch.int64, device=self.confusion_mat.device
+ )
+ for old_cluster_idx, new_cluster_idx in enumerate(old_to_new_cluster_idx):
+ if max_iou[old_cluster_idx] > 0.0:
+ confusion_mat_new[new_cluster_idx] += confusion_mat[old_cluster_idx]
+
+ # Important: use previously computed area_gt because it includes background predictions,
+ # whereas the new confusion matrix does not contain the bg predicted class anymore.
+ pairwise_iou, _, _, _ = self._compute_iou_from_confusion_mat(confusion_mat_new, area_gt)
+ max_iou, class_idxs = torch.max(pairwise_iou, dim=1)
+ valid = max_iou > 0.0 # Ignore clusters without any kind of overlap
+ class_idxs = class_idxs[valid]
+ cluster_idxs = torch.arange(pairwise_iou.shape[1])[valid]
+ matched_iou = pairwise_iou[cluster_idxs, class_idxs]
+ true_pos = confusion_mat_new[cluster_idxs, class_idxs]
+
+ iou = torch.zeros(n_classes, dtype=torch.float64, device=pairwise_iou.device)
+ iou[class_idxs] = matched_iou
+
+ accuracy = true_pos.sum().to(torch.float64) / area_gt.sum()
+ empty_classes = area_gt == 0
+
+ return iou, accuracy, empty_classes
+
+ @staticmethod
+ def _compute_iou_from_confusion_mat(
+ confusion_mat: torch.Tensor, area_gt: Optional[torch.Tensor] = None
+ ):
+ area_pred = torch.sum(confusion_mat, axis=1)
+ if area_gt is None:
+ area_gt = torch.sum(confusion_mat, axis=0)
+ union = area_pred.unsqueeze(1) + area_gt.unsqueeze(0) - confusion_mat
+ pairwise_iou = confusion_mat.to(torch.float64) / union
+
+ # Ignore classes that occurred in no image.
+ pairwise_iou[union == 0] = 0.0
+
+ return pairwise_iou, union, area_pred, area_gt
+
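+# Note on the confusion-matrix formulation used above: for predicted class i and ground truth
+# class j,
+#   IoU(i, j) = C[i, j] / (area_pred[i] + area_gt[j] - C[i, j]),
+# where C[i, j] counts pixels assigned to i by the prediction and to j by the ground truth,
+# accumulated over the whole dataset. `_compute_iou_from_confusion_mat` returns exactly this
+# matrix, on which the Hungarian / majority matching in `compute` then operates.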
+
+def _remap_one_hot_mask(
+ mask: torch.Tensor, new_classes: torch.Tensor, n_new_classes: int, strip_empty: bool = False
+):
+ """Remap classes from binary mask to new classes.
+
+ In the case of an overlap of classes for a point, the new class with the highest ID is
+ assigned to that point. If no class is assigned to a point, the point will have no class
+ assigned after remapping as well.
+
+ Args:
+ mask: Binary mask of shape (B, P, K) where K is the number of old classes and P is the
+ number of points.
+ new_classes: Tensor of shape (B, K) containing ids of new classes for each old class.
+ n_new_classes: Number of classes after remapping, i.e. highest class id that can occur.
+ strip_empty: Whether to remove the channel corresponding to empty pixels (class 0) from the output.
+
+ Returns:
+ Tensor of shape (B, P, J), where J is the new number of classes.
+ """
+ assert new_classes.shape[1] == mask.shape[2]
+ mask_dense = (mask * new_classes.unsqueeze(1)).max(dim=-1).values
+ mask = torch.nn.functional.one_hot(mask_dense.to(torch.long), num_classes=n_new_classes + 1)
+
+ if strip_empty:
+ mask = mask[..., 1:]
+
+ return mask
+
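+# Example: with one image containing two instances whose class ids are (3, 1),
+#   mask = torch.tensor([[[1, 0], [0, 1], [0, 0]]])               # (B=1, P=3, K=2), last point empty
+#   _remap_one_hot_mask(mask, torch.tensor([[3, 1]]), n_new_classes=3)
+# returns a (1, 3, 4) one-hot tensor: point 0 maps to class 3, point 1 to class 1, and the empty
+# point to the extra "no class" channel 0, which `strip_empty=True` would drop.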
+
+class SklearnClustering:
+ """Wrapper around scikit-learn clustering algorithms.
+
+ Args:
+ n_clusters: Number of clusters.
+ method: Clustering method to use.
+ clustering_kwargs: Dictionary of additional keyword arguments to pass to clustering object.
+ use_l2_normalization: Whether to L2 normalize the representations before clustering (but
+ after PCA).
+ use_pca: Whether to apply PCA before fitting the clusters.
+ pca_dimensions: Number of dimensions for PCA dimensionality reduction. If `None`, do not
+ reduce dimensions with PCA.
+ pca_kwargs: Dictionary of additional keyword arguments to pass to PCA object.
+ """
+
+ def __init__(
+ self,
+ n_clusters: int,
+ method: str = "kmeans",
+ clustering_kwargs: Optional[Dict[str, Any]] = None,
+ use_l2_normalization: bool = False,
+ use_pca: bool = False,
+ pca_dimensions: Optional[int] = None,
+ pca_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ methods = ("kmeans", "spectral")
+ if method not in methods:
+ raise ValueError(f"Unknown clustering method {method}. Valid values are {methods}.")
+
+ self._n_clusters = n_clusters
+ self.method = method
+ self.clustering_kwargs = clustering_kwargs
+ self.use_l2_normalization = use_l2_normalization
+ self.use_pca = use_pca
+ self.pca_dimensions = pca_dimensions
+ self.pca_kwargs = pca_kwargs
+
+ self._clustering = None
+ self._pca = None
+
+ @property
+ def n_clusters(self):
+ return self._n_clusters
+
+ def _init(self):
+ from sklearn import cluster, decomposition
+
+ kwargs = self.clustering_kwargs if self.clustering_kwargs is not None else {}
+ if self.method == "kmeans":
+ self._clustering = cluster.KMeans(n_clusters=self.n_clusters, **kwargs)
+ elif self.method == "spectral":
+ self._clustering = cluster.SpectralClustering(n_clusters=self.n_clusters, **kwargs)
+ else:
+ raise NotImplementedError(f"Clustering {self.method} not implemented.")
+
+ if self.use_pca:
+ kwargs = self.pca_kwargs if self.pca_kwargs is not None else {}
+ self._pca = decomposition.PCA(n_components=self.pca_dimensions, **kwargs)
+
+ def fit_predict(self, features: torch.Tensor):
+ self._init()
+ features = features.detach().cpu().numpy()
+ if self.use_pca:
+ features = self._pca.fit_transform(features)
+ if self.use_l2_normalization:
+ features /= np.maximum(np.linalg.norm(features, ord=2, axis=1, keepdims=True), 1e-8)
+ cluster_ids = self._clustering.fit_predict(features).astype(np.int64)
+ return torch.from_numpy(cluster_ids)
+
+ def predict(self, features: torch.Tensor) -> torch.Tensor:
+ if self._clustering is None:
+ raise ValueError("Clustering was not fitted. Call `fit_predict` first.")
+
+ features = features.detach().cpu().numpy()
+ if self.use_pca:
+ features = self._pca.transform(features)
+ if self.use_l2_normalization:
+ features /= np.maximum(np.linalg.norm(features, ord=2, axis=1, keepdims=True), 1e-8)
+ cluster_ids = self._clustering.predict(features).astype(np.int64)
+ return torch.from_numpy(cluster_ids)
+
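+# Minimal usage sketch (feature shapes and hyperparameters are illustrative only):
+#   clustering = SklearnClustering(n_clusters=10, method="kmeans", use_pca=True, pca_dimensions=50)
+#   cluster_ids = clustering.fit_predict(features)    # features: (N, D) torch tensor
+#   new_ids = clustering.predict(new_features)        # reuses the fitted KMeans / PCA
+# Note that `predict` only makes sense for "kmeans"; sklearn's SpectralClustering does not expose
+# a `predict` method.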
+
+from sklearn.metrics.cluster import (
+ adjusted_mutual_info_score,
+ adjusted_rand_score,
+ contingency_matrix,
+ fowlkes_mallows_score,
+ mutual_info_score,
+ normalized_mutual_info_score,
+ pair_confusion_matrix,
+ rand_score,
+)
+
+
+class MutualInfoAndPairCounting(torchmetrics.Metric, RoutableMixin):
+ """Computes Precision, Recall, F1, AMI, NMI and Purity metric."""
+
+ def __init__(
+ self,
+ prediction_path: str,
+ target_path: str,
+ ignore_path: Optional[str] = None,
+ foreground: bool = True,
+ convert_target_one_hot: bool = False,
+ ignore_overlaps: bool = False,
+ back_as_class: bool = True,
+ metric_name: str = "ari_sklearn",
+ ):
+ torchmetrics.Metric.__init__(self)
+ RoutableMixin.__init__(
+ self, {"prediction": prediction_path, "target": target_path, "ignore": ignore_path}
+ )
+
+ self.convert_target_one_hot = convert_target_one_hot
+ self.foreground = foreground
+ self.ignore_overlaps = ignore_overlaps
+
+ self.back_as_class = back_as_class
+ self.metric_name = metric_name
+ assert self.metric_name in ("ari_sklearn", "ami", "nmi", "precision", "recall", "f1", "purity")
+
+ self.add_state(
+ "values", default=torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx="sum"
+ )
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ @RoutableMixin.route
+ def update(
+ self, prediction: torch.Tensor, target: torch.Tensor, ignore: Optional[torch.Tensor] = None
+ ):
+ """Update this metric.
+
+ Args:
+ prediction: Predicted mask of shape (B, C, H, W) or (B, F, C, H, W), where C is the
+ number of classes.
+ target: Ground truth mask of shape (B, K, H, W) or (B, F, K, H, W), where K is the
+ number of classes.
+ ignore: Ignore mask of shape (B, 1, H, W) or (B, 1, K, H, W)
+ """
+ if prediction.ndim == 5:
+ # Merge frames, height and width to single dimension.
+ prediction = prediction.transpose(1, 2).flatten(-3, -1)
+ target = target.transpose(1, 2).flatten(-3, -1)
+ if ignore is not None:
+ ignore = ignore.to(torch.bool).transpose(1, 2).flatten(-3, -1)
+ elif prediction.ndim == 4:
+ # Merge height and width to single dimension.
+ prediction = prediction.flatten(-2, -1)
+ target = target.flatten(-2, -1)
+ if ignore is not None:
+ ignore = ignore.to(torch.bool).flatten(-2, -1)
+ else:
+ raise ValueError(f"Incorrect input shape: f{prediction.shape}")
+
+ if self.ignore_overlaps:
+ overlaps = (target > 0).sum(1, keepdim=True) > 1
+ if ignore is None:
+ ignore = overlaps
+ else:
+ ignore = ignore | overlaps
+
+ if ignore is not None:
+ assert ignore.ndim == 3 and ignore.shape[1] == 1
+ prediction = prediction.clone()
+ prediction[ignore.expand_as(prediction)] = 0
+ target = target.clone()
+ target[ignore.expand_as(target)] = 0
+
+ # Make channels / gt labels the last dimension.
+ prediction = prediction.transpose(-2, -1)
+ target = target.transpose(-2, -1)
+
+ if self.convert_target_one_hot:
+ target_oh = tensor_to_one_hot(target, dim=2)
+ # For empty pixels (all values zero), one-hot assigns 1 to the first class, correct for
+ # this (then it is technically not one-hot anymore).
+ target_oh[:, :, 0][target.sum(dim=2) == 0] = 0
+ target = target_oh
+
+ # Should be either 0 (empty, padding) or 1 (single object).
+ assert torch.all(target.sum(dim=-1) < 2), "Issues with target format, mask non-exclusive"
+
+ # Permute from (B, P, K) to (B, K, P) and iterate over batch elements, so that each
+ # pred / target has shape (K, P): classes first, points last.
+ for pred, target in zip(prediction.permute(0, 2, 1), target.permute(0, 2, 1)):
+ if self.foreground:
+ if self.back_as_class:
+ # Drop the background channel (class 0) from the ground truth.
+ target = target[1:]
+ # Restrict the evaluation to points covered by at least one foreground class.
+ fore_ground_point = target.sum(-2) > 0
+ target = target[..., fore_ground_point]
+ pred = pred[..., fore_ground_point]
+
+ nonzero_classes = torch.sum(target, dim=-1) > 0
+ target = target[nonzero_classes] # Remove empty (e.g. padded) classes
+ if len(target) == 0:
+ continue # Skip elements without any target mask
+
+ pred = pred.argmax(-2).detach().cpu().numpy()
+ target = target.argmax(-2).detach().cpu().numpy()
+ if self.metric_name == "purity":
+ contingency = contingency_matrix(target, pred)
+ true_idxs, pred_idxs = scipy.optimize.linear_sum_assignment(
+ contingency, maximize=True
+ )
+ # acc = np.sum(np.amax(contingency, axis=0)) / np.sum(contingency)
+ purity = np.sum([contingency[i, j] for i, j in zip(true_idxs, pred_idxs)]) / np.sum(contingency)
+ self.values += purity
+ self.total += 1
+
+ if self.metric_name == "ami":
+ ami = adjusted_mutual_info_score(target, pred)
+ self.values += ami
+ self.total += 1
+ if self.metric_name == "nmi":
+ nmi = normalized_mutual_info_score(target, pred)
+ self.values += nmi
+ self.total += 1
+ if self.metric_name == "ari_sklearn":
+ ari = adjusted_rand_score(target, pred)
+ self.values += ari
+ self.total += 1
+
+ if self.metric_name in ("precision","recall","f1"):
+ ins_confusion_matrix = pair_confusion_matrix(target, pred)
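+ # sklearn's pair_confusion_matrix counts (ordered) pairs of points: entry [0, 0] are pairs
+ # separated in both clusterings, [1, 1] pairs grouped together in both, [0, 1] pairs grouped
+ # only in the prediction, and [1, 0] pairs grouped only in the ground truth. These give the
+ # TP / FP / FN counts used below; the factor 2 from ordered pairs cancels in the ratios.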
+ tn = ins_confusion_matrix[0, 0]
+ fn = ins_confusion_matrix[1, 0]
+ tp = ins_confusion_matrix[1, 1]
+ fp = ins_confusion_matrix[0, 1]
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ F1 = 2 * precision * recall / (precision + recall)
+
+ if self.metric_name == "precision" and not np.isnan(precision):
+ self.values += precision
+ self.total += 1
+
+ if self.metric_name == "recall" and not np.isnan(recall):
+ self.values += recall
+ self.total += 1
+
+ if self.metric_name == "f1" and not np.isnan(F1):
+ self.values += F1
+ self.total += 1
+
+
+ def compute(self) -> torch.Tensor:
+ if self.total == 0:
+ return torch.zeros_like(self.values)
+ return self.values / self.total
+
diff --git a/ocl/mha.py b/ocl/mha.py
new file mode 100644
index 0000000..d9cae1f
--- /dev/null
+++ b/ocl/mha.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class ScaledDotProductAttention(nn.Module):
+ """Scaled Dot-Product Attention."""
+
+ def __init__(self, temperature, attn_dropout=0.0):
+ super().__init__()
+ self.temperature = temperature
+ self.dropout = nn.Dropout(attn_dropout)
+
+ def forward(self, q, k, v, mask=None):
+ attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
+
+ # if mask is not None:
+ # attn = attn.masked_fill(mask == 0, -1e9)
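+ # The additive-bias form below matches the masked_fill variant above for a binary mask, but
+ # keeps the computation differentiable with respect to `mask` (relevant e.g. when the mask
+ # itself comes from a sampling or relaxation step).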
+ if mask is not None:
+ bias = (1 - mask) * (-1e9)
+ attn = attn * mask + bias
+
+ attn = F.softmax(attn, dim=-1)
+ output = torch.matmul(attn, v)
+
+ return output, attn
+
+
+class MultiHeadAttention_for_index(nn.Module):
+ """Multi-Head Attention module."""
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0):
+ super().__init__()
+
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
+ self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
+ self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
+ self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
+
+ nn.init.eye_(self.w_ks.weight)
+ nn.init.eye_(self.w_vs.weight)
+
+ self.attention = ScaledDotProductAttention(temperature=d_k**0.5) # temperature=d_k ** 0.5
+
+ self.dropout = nn.Dropout(dropout)
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ def forward(self, q, k, v, mask=None):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+ # Pass through the pre-attention projection: b x lq x (n*dv)
+ # Separate different heads: b x lq x n x dv
+ q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+ k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+ v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+
+ # Transpose for attention dot product: b x n x lq x dv
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+ if mask is not None:
+ mask = mask.unsqueeze(1) # For head axis broadcasting.
+
+ q, attn = self.attention(q, k, v, mask=mask)
+
+ # Transpose to move the head dimension back: b x lq x n x dv
+ # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
+ q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+ q = self.dropout(self.fc(q))
+
+ q = self.layer_norm(q)
+
+ attn = torch.mean(attn, 1)
+ return q, attn
+
+
+class MultiHeadAttention(nn.Module):
+ """Multi-Head Attention module."""
+
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0):
+ super().__init__()
+
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
+ self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
+ self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
+ self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
+
+ nn.init.eye_(self.w_qs.weight)
+ nn.init.eye_(self.w_ks.weight)
+ nn.init.eye_(self.w_vs.weight)
+ nn.init.eye_(self.fc.weight)
+ self.attention = ScaledDotProductAttention(temperature=0.5)
+
+ self.dropout = nn.Dropout(dropout)
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ def forward(self, q, k, v, mask=None):
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+ # Pass through the pre-attention projection: b x lq x (n*dv)
+ # Separate different heads: b x lq x n x dv
+ q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+ k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+ v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+
+ # Transpose for attention dot product: b x n x lq x dv
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+ if mask is not None:
+ mask = mask.unsqueeze(1) # For head axis broadcasting.
+
+ q, attn = self.attention(q, k, v, mask=mask)
+
+ # Transpose to move the head dimension back: b x lq x n x dv
+ # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
+ q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+ q = self.dropout(self.fc(q))
+
+ # Just return the weighted sum; do not apply a residual connection.
+ attn = torch.mean(attn, 1)
+ return q, attn
diff --git a/ocl/models/__init__.py b/ocl/models/__init__.py
new file mode 100644
index 0000000..e482a3e
--- /dev/null
+++ b/ocl/models/__init__.py
@@ -0,0 +1,7 @@
+"""Models defined in code."""
+# from ocl.models.sa_detr import SA_DETR
+# from ocl.models.savi import SAVi
+# from ocl.models.savi_with_memory import SAVi_mem
+
+# __all__ = ["SAVi", "SAVi_mem", "SA_DETR"]
diff --git a/ocl/models/image_grouping.py b/ocl/models/image_grouping.py
new file mode 100644
index 0000000..4af028b
--- /dev/null
+++ b/ocl/models/image_grouping.py
@@ -0,0 +1,71 @@
+from typing import Any, Dict
+
+from torch import nn
+
+from ocl.path_defaults import VIDEO
+from ocl.utils.trees import get_tree_element
+
+
+class GroupingImg(nn.Module):
+ def __init__(
+ self,
+ conditioning: nn.Module,
+ feature_extractor: nn.Module,
+ perceptual_grouping: nn.Module,
+ object_decoder: nn.Module,
+ masks_as_image=None,
+ decoder_mode="MLP",
+ ):
+ super().__init__()
+ self.conditioning = conditioning
+ self.feature_extractor = feature_extractor
+ self.perceptual_grouping = perceptual_grouping
+ self.object_decoder = object_decoder
+ self.masks_as_image = masks_as_image
+ self.decoder_mode = decoder_mode
+
+ def forward(self, inputs: Dict[str, Any]):
+ outputs = inputs
+ video = get_tree_element(inputs, VIDEO.split("."))
+ video.shape
+
+ # feature extraction
+ features = self.feature_extractor(video=video)
+ outputs["feature_extractor"] = features
+
+ # slot initialization
+ batch_size = video.shape[0]
+ conditioning = self.conditioning(batch_size=batch_size)
+ outputs["conditioning"] = conditioning
+
+ # slot computation
+ perceptual_grouping_output = self.perceptual_grouping(
+ extracted_features=features, conditioning=conditioning
+ )
+ outputs["perceptual_grouping"] = perceptual_grouping_output
+
+ # slot decoding
+ object_features = get_tree_element(outputs, "perceptual_grouping.objects".split("."))
+ masks = get_tree_element(outputs, "perceptual_grouping.feature_attributions".split("."))
+ target = get_tree_element(outputs, "feature_extractor.features".split("."))
+ image = get_tree_element(outputs, "input.image".split("."))
+ empty_object = None
+
+ if self.decoder_mode == "MLP":
+ decoder_output = self.object_decoder(object_features=object_features,
+ target=target,
+ image = image)
+ elif self.decoder_mode == "Transformer":
+ decoder_output = self.object_decoder(object_features=object_features,
+ masks=masks,
+ target=target,
+ image=image,
+ empty_objects = None)
+ else:
+ raise RuntimeError
+
+ outputs["object_decoder"] = decoder_output
+ outputs["masks_as_image"]= self.masks_as_image(tensor = get_tree_element(outputs, "object_decoder.masks".split(".")))
+
+ return outputs
diff --git a/ocl/models/image_grouping_adaslot.py b/ocl/models/image_grouping_adaslot.py
new file mode 100644
index 0000000..b639e43
--- /dev/null
+++ b/ocl/models/image_grouping_adaslot.py
@@ -0,0 +1,77 @@
+from typing import Any, Dict
+
+from torch import nn
+
+from ocl.path_defaults import VIDEO
+from ocl.utils.trees import get_tree_element
+import torch
+
+class GroupingImgGumbel(nn.Module):
+ def __init__(
+ self,
+ conditioning: nn.Module,
+ feature_extractor: nn.Module,
+ perceptual_grouping: nn.Module,
+ object_decoder: nn.Module,
+ masks_as_image = None,
+ decoder_mode = "MLP",
+ ):
+ super().__init__()
+ self.conditioning = conditioning
+ self.feature_extractor = feature_extractor
+ self.perceptual_grouping = perceptual_grouping
+ self.object_decoder = object_decoder
+ self.masks_as_image = masks_as_image
+ self.decoder_mode = decoder_mode
+ object_dim = self.conditioning.object_dim
+
+ def forward(self, inputs: Dict[str, Any]):
+ outputs = inputs
+ video = get_tree_element(inputs, VIDEO.split("."))
+ video.shape
+
+ # feature extraction
+ features = self.feature_extractor(video=video)
+ outputs["feature_extractor"] = features
+
+ # slot initialization
+ batch_size = video.shape[0]
+ conditioning = self.conditioning(batch_size=batch_size)
+ outputs["conditioning"] = conditioning
+
+ # slot computation
+ perceptual_grouping_output = self.perceptual_grouping(
+ extracted_features=features, conditioning=conditioning
+ )
+ outputs["perceptual_grouping"] = perceptual_grouping_output
+ outputs["hard_keep_decision"] = perceptual_grouping_output["hard_keep_decision"]
+ outputs["slots_keep_prob"] = perceptual_grouping_output["slots_keep_prob"]
+
+ ##
+ object_features, hard_keep_decision = perceptual_grouping_output["objects"], perceptual_grouping_output["hard_keep_decision"] # (b * t, s, d), (b * t, s, n)
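+ # `hard_keep_decision` is the sampled per-slot keep decision produced by the grouping module
+ # (with `slots_keep_prob` holding the underlying keep probabilities); it is forwarded to the
+ # decoder as `left_mask` so that dropped slots can be suppressed during decoding.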
+ # slot decoding
+ # object_features = get_tree_element(outputs, "perceptual_grouping.objects".split("."))
+ masks = get_tree_element(outputs, "perceptual_grouping.feature_attributions".split("."))
+ target = get_tree_element(outputs, "feature_extractor.features".split("."))
+ image = get_tree_element(outputs, "input.image".split("."))
+ empty_object = None
+
+ if self.decoder_mode == "MLP":
+ decoder_output = self.object_decoder(object_features=object_features,
+ target=target,
+ image = image,
+ left_mask = hard_keep_decision)
+ elif self.decoder_mode == "Transformer":
+ decoder_output = self.object_decoder(object_features=object_features,
+ masks=masks,
+ target=target,
+ image=image,
+ empty_objects = None,
+ left_mask = hard_keep_decision)
+ else:
+ raise RuntimeError
+
+ outputs["object_decoder"] = decoder_output
+ if self.masks_as_image is not None:
+ outputs["masks_as_image"] = self.masks_as_image(tensor=get_tree_element(outputs, "object_decoder.masks".split(".")))
+ return outputs
\ No newline at end of file
diff --git a/ocl/models/image_grouping_adaslot_pixel.py b/ocl/models/image_grouping_adaslot_pixel.py
new file mode 100644
index 0000000..e724c6b
--- /dev/null
+++ b/ocl/models/image_grouping_adaslot_pixel.py
@@ -0,0 +1,58 @@
+from typing import Any, Dict
+
+from torch import nn
+
+from ocl.path_defaults import VIDEO
+from ocl.utils.trees import get_tree_element
+import torch
+
+class GroupingImgGumbel(nn.Module):
+ def __init__(
+ self,
+ conditioning: nn.Module,
+ feature_extractor: nn.Module,
+ perceptual_grouping: nn.Module,
+ object_decoder: nn.Module,
+ masks_as_image = None,
+ ):
+ super().__init__()
+ self.conditioning = conditioning
+ self.feature_extractor = feature_extractor
+ self.perceptual_grouping = perceptual_grouping
+ self.object_decoder = object_decoder
+ self.masks_as_image = masks_as_image
+ object_dim = self.conditioning.object_dim
+
+ def forward(self, inputs: Dict[str, Any]):
+ outputs = inputs
+ video = get_tree_element(inputs, VIDEO.split("."))
+ video.shape
+
+ # feature extraction
+ features = self.feature_extractor(video=video)
+ outputs["feature_extractor"] = features
+
+ # slot initialization
+ batch_size = video.shape[0]
+ conditioning = self.conditioning(batch_size=batch_size)
+ outputs["conditioning"] = conditioning
+
+ # slot computation
+ perceptual_grouping_output = self.perceptual_grouping(
+ extracted_features=features, conditioning=conditioning
+ )
+ outputs["perceptual_grouping"] = perceptual_grouping_output
+ outputs["hard_keep_decision"] = perceptual_grouping_output["hard_keep_decision"]
+ outputs["slots_keep_prob"] = perceptual_grouping_output["slots_keep_prob"]
+
+ ##
+ object_features, hard_keep_decision = perceptual_grouping_output["objects"], perceptual_grouping_output["hard_keep_decision"] # (b * t, s, d), (b * t, s, n)
+ # slot decoding
+ # object_features = get_tree_element(outputs, "perceptual_grouping.objects".split("."))
+ decoder_output = self.object_decoder(object_features=object_features,
+ left_mask=hard_keep_decision)
+
+ outputs["object_decoder"] = decoder_output
+ if self.masks_as_image is not None:
+ outputs["masks_as_image"] = self.masks_as_image(tensor=get_tree_element(outputs, "object_decoder.masks".split(".")))
+ return outputs
\ No newline at end of file
diff --git a/ocl/neural_networks/__init__.py b/ocl/neural_networks/__init__.py
new file mode 100644
index 0000000..a834713
--- /dev/null
+++ b/ocl/neural_networks/__init__.py
@@ -0,0 +1,13 @@
+from ocl.neural_networks.convenience import (
+ build_mlp,
+ build_transformer_decoder,
+ build_transformer_encoder,
+ build_two_layer_mlp,
+)
+
+__all__ = [
+ "build_mlp",
+ "build_transformer_decoder",
+ "build_transformer_encoder",
+ "build_two_layer_mlp",
+]
diff --git a/ocl/neural_networks/convenience.py b/ocl/neural_networks/convenience.py
new file mode 100644
index 0000000..418787a
--- /dev/null
+++ b/ocl/neural_networks/convenience.py
@@ -0,0 +1,158 @@
+"""Convenience functions for the construction neural networks using config."""
+from typing import Callable, List, Optional, Union
+
+from torch import nn
+
+from ocl.neural_networks.extensions import TransformerDecoderWithAttention
+from ocl.neural_networks.wrappers import Residual
+
+
+class ReLUSquared(nn.Module):
+ def __init__(self, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return nn.functional.relu(x, inplace=self.inplace) ** 2
+
+
+def get_activation_fn(name: str, inplace: bool = True, leaky_relu_slope: Optional[float] = None):
+ if callable(name):
+ return name
+
+ name = name.lower()
+ if name == "relu":
+ return nn.ReLU(inplace=inplace)
+ elif name == "relu_squared":
+ return ReLUSquared(inplace=inplace)
+ elif name == "leaky_relu":
+ if leaky_relu_slope is None:
+ raise ValueError("Slope of leaky ReLU was not defined")
+ return nn.LeakyReLU(leaky_relu_slope, inplace=inplace)
+ elif name == "tanh":
+ return nn.Tanh()
+ elif name == "sigmoid":
+ return nn.Sigmoid()
+ elif name == "identity":
+ return nn.Identity()
+ else:
+ raise ValueError(f"Unknown activation function {name}")
+
+
+def build_mlp(
+ input_dim: int,
+ output_dim: int,
+ features: List[int],
+ activation_fn: Union[str, Callable] = "relu",
+ final_activation_fn: Optional[Union[str, Callable]] = None,
+ initial_layer_norm: bool = False,
+ residual: bool = False,
+) -> nn.Sequential:
+ layers = []
+ current_dim = input_dim
+ if initial_layer_norm:
+ layers.append(nn.LayerNorm(current_dim))
+
+ for n_features in features:
+ layers.append(nn.Linear(current_dim, n_features))
+ nn.init.zeros_(layers[-1].bias)
+ layers.append(get_activation_fn(activation_fn))
+ current_dim = n_features
+
+ layers.append(nn.Linear(current_dim, output_dim))
+ nn.init.zeros_(layers[-1].bias)
+ if final_activation_fn is not None:
+ layers.append(get_activation_fn(final_activation_fn))
+
+ if residual:
+ return Residual(nn.Sequential(*layers))
+ return nn.Sequential(*layers)
+
+
+def build_two_layer_mlp(
+ input_dim, output_dim, hidden_dim, initial_layer_norm: bool = False, residual: bool = False
+):
+ """Build a two layer MLP, with optional initial layer norm.
+
+ Separate class as this type of construction is used very often for slot attention and
+ transformers.
+ """
+ return build_mlp(
+ input_dim, output_dim, [hidden_dim], initial_layer_norm=initial_layer_norm, residual=residual
+ )
+
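+# For reference, build_two_layer_mlp(d_in, d_out, d_hidden, initial_layer_norm=True) expands to
+# roughly nn.Sequential(LayerNorm(d_in), Linear(d_in, d_hidden), ReLU(), Linear(d_hidden, d_out)),
+# optionally wrapped in Residual when residual=True.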
+
+def build_transformer_encoder(
+ input_dim: int,
+ output_dim: int,
+ n_layers: int,
+ n_heads: int,
+ hidden_dim: Optional[int] = None,
+ dropout: float = 0.0,
+ activation_fn: Union[str, Callable] = "relu",
+ layer_norm_eps: float = 1e-5,
+ use_output_transform: bool = True,
+):
+ if hidden_dim is None:
+ hidden_dim = 4 * input_dim
+
+ layers = []
+ for _ in range(n_layers):
+ layers.append(
+ nn.TransformerEncoderLayer(
+ d_model=input_dim,
+ nhead=n_heads,
+ dim_feedforward=hidden_dim,
+ dropout=dropout,
+ activation=activation_fn,
+ layer_norm_eps=layer_norm_eps,
+ batch_first=True,
+ norm_first=True,
+ )
+ )
+
+ if use_output_transform:
+ layers.append(nn.LayerNorm(input_dim, eps=layer_norm_eps))
+ output_transform = nn.Linear(input_dim, output_dim, bias=True)
+ nn.init.xavier_uniform_(output_transform.weight)
+ nn.init.zeros_(output_transform.bias)
+ layers.append(output_transform)
+
+ return nn.Sequential(*layers)
+
+
+def build_transformer_decoder(
+ input_dim: int,
+ output_dim: int,
+ n_layers: int,
+ n_heads: int,
+ hidden_dim: Optional[int] = None,
+ dropout: float = 0.0,
+ activation_fn: Union[str, Callable] = "relu",
+ layer_norm_eps: float = 1e-5,
+ return_attention_weights: bool = False,
+ attention_weight_type: Union[int, str] = -1,
+):
+ if hidden_dim is None:
+ hidden_dim = 4 * input_dim
+
+ decoder_layer = nn.TransformerDecoderLayer(
+ d_model=input_dim,
+ nhead=n_heads,
+ dim_feedforward=hidden_dim,
+ dropout=dropout,
+ activation=activation_fn,
+ layer_norm_eps=layer_norm_eps,
+ batch_first=True,
+ norm_first=True,
+ )
+
+ if return_attention_weights:
+ return TransformerDecoderWithAttention(
+ decoder_layer,
+ n_layers,
+ return_attention_weights=True,
+ attention_weight_type=attention_weight_type,
+ )
+ else:
+ return nn.TransformerDecoder(decoder_layer, n_layers)
diff --git a/ocl/neural_networks/extensions.py b/ocl/neural_networks/extensions.py
new file mode 100644
index 0000000..e840b9a
--- /dev/null
+++ b/ocl/neural_networks/extensions.py
@@ -0,0 +1,109 @@
+"""Extensions of existing layers to implement additional functionality."""
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+
+class TransformerDecoderWithAttention(nn.TransformerDecoder):
+ """Modified nn.TransformerDecoder class that returns attention weights over memory."""
+
+ def __init__(
+ self,
+ decoder_layer,
+ num_layers,
+ norm=None,
+ return_attention_weights=False,
+ attention_weight_type: Union[int, str] = "mean",
+ ):
+ super(TransformerDecoderWithAttention, self).__init__(decoder_layer, num_layers, norm)
+
+ if return_attention_weights:
+ self.attention_hooks = []
+ for layer in self.layers:
+ self.attention_hooks.append(self._prepare_layer(layer))
+ else:
+ self.attention_hooks = None
+
+ if isinstance(attention_weight_type, int):
+ if attention_weight_type >= num_layers or attention_weight_type < -num_layers:
+ raise ValueError(
+ f"Index {attention_weight_type} exceeds number of layers {num_layers}"
+ )
+ elif attention_weight_type != "mean":
+ raise ValueError("`weights` needs to be a number or 'mean'.")
+ self.weights = attention_weight_type
+
+ def _prepare_layer(self, layer):
+ assert isinstance(layer, nn.TransformerDecoderLayer)
+
+ def _mha_block(self, x, mem, attn_mask, key_padding_mask):
+ x = self.multihead_attn(
+ x,
+ mem,
+ mem,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=True,
+ )[0]
+ return self.dropout2(x)
+
+ # Patch _mha_block method to compute attention weights
+ layer._mha_block = _mha_block.__get__(layer, nn.TransformerDecoderLayer)
+
+ class AttentionHook:
+ def __init__(self):
+ self._attention = None
+
+ def pop(self) -> torch.Tensor:
+ assert self._attention is not None, "Forward was not called yet!"
+ attention = self._attention
+ self._attention = None
+ return attention
+
+ def __call__(self, module, inp, outp):
+ self._attention = outp[1]
+
+ hook = AttentionHook()
+ layer.multihead_attn.register_forward_hook(hook)
+ return hook
+
+ def forward(
+ self,
+ tgt: torch.Tensor,
+ memory: torch.Tensor,
+ tgt_mask: Optional[torch.Tensor] = None,
+ memory_mask: Optional[torch.Tensor] = None,
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
+ memory_key_padding_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ output = tgt
+
+ for mod in self.layers:
+ output = mod(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ )
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ if self.attention_hooks is not None:
+ attentions = []
+ for hook in self.attention_hooks:
+ attentions.append(hook.pop())
+
+ if self.weights == "mean":
+ attentions = torch.stack(attentions, dim=-1)
+ # Take mean over all layers
+ attention = attentions.mean(dim=-1)
+ else:
+ attention = attentions[self.weights]
+
+ return output, attention.transpose(1, 2)
+ else:
+ return output
diff --git a/ocl/neural_networks/feature_pyramid_networks.py b/ocl/neural_networks/feature_pyramid_networks.py
new file mode 100644
index 0000000..a59a580
--- /dev/null
+++ b/ocl/neural_networks/feature_pyramid_networks.py
@@ -0,0 +1,117 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ocl.utils.routing import RoutableMixin
+
+
+class FeaturePyramidDecoder(nn.Module, RoutableMixin):
+ def __init__(
+ self,
+ slot_dim: int,
+ feature_dim: int,
+ mask_path: Optional[str] = None,
+ slots_path: Optional[str] = None,
+ features_path: Optional[str] = None,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "slots": slots_path,
+ "mask": mask_path,
+ "features": features_path,
+ },
+ )
+
+ inter_dims = [slot_dim, slot_dim // 2, slot_dim // 4, slot_dim // 8, slot_dim // 16]
+ # Depth dimension is slot dimension, no padding there and kernel size 1.
+ self.lay1 = torch.nn.Conv3d(inter_dims[0], inter_dims[0], (1, 3, 3), padding=(0, 1, 1))
+ self.gn1 = torch.nn.GroupNorm(8, inter_dims[0])
+ self.lay2 = torch.nn.Conv3d(inter_dims[0], inter_dims[1], (1, 3, 3), padding=(0, 1, 1))
+ self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
+ self.lay3 = torch.nn.Conv3d(inter_dims[1], inter_dims[2], (1, 3, 3), padding=(0, 1, 1))
+ self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
+ self.lay4 = torch.nn.Conv3d(inter_dims[2], inter_dims[3], (1, 3, 3), padding=(0, 1, 1))
+ self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
+ self.lay5 = torch.nn.Conv3d(inter_dims[3], inter_dims[4], (1, 3, 3), padding=(0, 1, 1))
+ self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
+ self.out_lay = torch.nn.ConvTranspose3d(
+ inter_dims[4],
+ 1,
+ stride=(1, 2, 2),
+ kernel_size=(1, 3, 3),
+ padding=(0, 1, 1),
+ output_padding=(0, 1, 1),
+ )
+
+ upsampled_dim = feature_dim // 8
+ self.upsampling = nn.ConvTranspose2d(
+ feature_dim, upsampled_dim, kernel_size=8, stride=8
+ ) # 112 x 112
+ self.adapter1 = nn.Conv2d(
+ upsampled_dim, inter_dims[0], kernel_size=5, padding=2, stride=8
+ ) # Should downsample 112 to 14
+ self.adapter2 = nn.Conv2d(
+ upsampled_dim, inter_dims[1], kernel_size=5, padding=2, stride=4
+ ) # 28x28
+ self.adapter3 = nn.Conv2d(
+ upsampled_dim, inter_dims[2], kernel_size=5, padding=2, stride=2
+ ) # 56 x 56
+ self.adapter4 = nn.Conv2d(
+ upsampled_dim, inter_dims[3], kernel_size=5, padding=2, stride=1
+ ) # 112 x 112
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv3d):
+ nn.init.kaiming_uniform_(m.weight, a=1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, slots: torch.Tensor, mask: torch.Tensor, features: torch.Tensor):
+ # Bring features into image format with channels first
+ features = features.unflatten(1, (14, 14)).permute(0, 3, 1, 2)
+ mask = mask.unflatten(-1, (14, 14))
+ # Use depth dimension for slots
+ x = slots.transpose(1, 2)[..., None, None] * mask.unsqueeze(1)
+ bs, n_channels, n_slots, width, height = x.shape
+
+ upsampled_features = self.upsampling(features)
+
+ # Add fake depth dimension for broadcasting and upsample representation.
+ x = self.lay1(x) + self.adapter1(upsampled_features).unsqueeze(2)
+ x = self.gn1(x)
+ x = F.relu(x)
+ x = self.lay2(x)
+ x = self.gn2(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter2(upsampled_features)
+ # Add fake depth dimension for broadcasting and upsample representation.
+ x = cur_fpn.unsqueeze(2) + F.interpolate(
+ x, size=(n_slots,) + cur_fpn.shape[-2:], mode="nearest"
+ )
+ x = self.lay3(x)
+ x = self.gn3(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter3(upsampled_features)
+ x = cur_fpn.unsqueeze(2) + F.interpolate(
+ x, size=(n_slots,) + cur_fpn.shape[-2:], mode="nearest"
+ )
+ x = self.lay4(x)
+ x = self.gn4(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter4(upsampled_features)
+ x = cur_fpn.unsqueeze(2) + F.interpolate(
+ x, size=(n_slots,) + cur_fpn.shape[-2:], mode="nearest"
+ )
+ x = self.lay5(x)
+ x = self.gn5(x)
+ x = F.relu(x)
+
+ # Squeeze channel dimension.
+ x = self.out_lay(x).squeeze(1).softmax(1)
+ return x
diff --git a/ocl/neural_networks/positional_embedding.py b/ocl/neural_networks/positional_embedding.py
new file mode 100644
index 0000000..44bc83e
--- /dev/null
+++ b/ocl/neural_networks/positional_embedding.py
@@ -0,0 +1,63 @@
+"""Implementation of different positional embeddings."""
+import torch
+from torch import nn
+
+
+class SoftPositionEmbed(nn.Module):
+ """Embeding of positions using convex combination of learnable tensors.
+
+ This assumes that the input positions are between 0 and 1.
+ """
+
+ def __init__(
+ self, n_spatial_dims: int, feature_dim: int, cnn_channel_order=False, savi_style=False
+ ):
+ """__init__.
+
+ Args:
+ n_spatial_dims (int): Number of spatial dimensions.
+ feature_dim (int): Dimensionality of the input features.
+ cnn_channel_order (bool): Assume features are in CNN channel order (i.e. C x H x W).
+ savi_style (bool): Use savi style positional encoding, where positions are normalized
+ between -1 and 1 and a single dense layer is used for embedding.
+ """
+ super().__init__()
+ self.savi_style = savi_style
+ n_features = n_spatial_dims if savi_style else 2 * n_spatial_dims
+ self.dense = nn.Linear(in_features=n_features, out_features=feature_dim)
+ self.cnn_channel_order = cnn_channel_order
+
+ def forward(self, inputs: torch.Tensor, positions: torch.Tensor):
+ if self.savi_style:
+ # Rescale positional encoding to -1 to 1
+ positions = (positions - 0.5) * 2
+ else:
+ positions = torch.cat([positions, 1 - positions], axis=-1)
+ emb_proj = self.dense(positions)
+ if self.cnn_channel_order:
+ emb_proj = emb_proj.permute(*range(inputs.ndim - 3), -1, -3, -2)
+ return inputs + emb_proj
+
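+# Example: for 2D positions normalized to [0, 1], the default (non-savi) variant embeds
+# cat([positions, 1 - positions]) (4 features) through the single dense layer and adds the result
+# to the inputs; with cnn_channel_order=True the embedding is permuted to channels-first
+# (C x H x W) before the addition.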
+
+class LearnedAdditivePositionalEmbed(nn.Module):
+ """Add positional encoding as in SLATE."""
+
+ def __init__(self, max_len, d_model, dropout=0.0):
+ super().__init__()
+ self.dropout = nn.Dropout(dropout)
+ self.pe = nn.Parameter(torch.zeros(1, max_len, d_model), requires_grad=True)
+ nn.init.trunc_normal_(self.pe)
+
+ def forward(self, input):
+ T = input.shape[1]
+ return self.dropout(input + self.pe[:, :T])
+
+
+class DummyPositionEmbed(nn.Module):
+ """Embedding that just passes through inputs without adding any positional embeddings."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, inputs: torch.Tensor, positions: torch.Tensor):
+ return inputs
diff --git a/ocl/neural_networks/slate.py b/ocl/neural_networks/slate.py
new file mode 100644
index 0000000..d0c0417
--- /dev/null
+++ b/ocl/neural_networks/slate.py
@@ -0,0 +1,55 @@
+"""Neural networks used for the implemenation of SLATE."""
+import torch
+from torch import nn
+
+
+class OneHotDictionary(nn.Module):
+ def __init__(self, vocab_size: int, emb_size: int):
+ super().__init__()
+ self.dictionary = nn.Embedding(vocab_size, emb_size)
+
+ def forward(self, x):
+ tokens = torch.argmax(x, dim=-1) # batch_size x N
+ token_embs = self.dictionary(tokens) # batch_size x N x emb_size
+ return token_embs
+
+
+class Conv2dBlockWithGroupNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ padding_mode="zeros",
+ weight_init="xavier",
+ ):
+ super().__init__()
+ self.conv2d = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ )
+
+ if weight_init == "kaiming":
+ nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity="relu")
+ else:
+ nn.init.xavier_uniform_(self.conv2d.weight)
+
+ if bias:
+ nn.init.zeros_(self.conv2d.bias)
+ self.group_norm = nn.GroupNorm(1, out_channels)
+
+ def forward(self, x):
+ x = self.conv2d(x)
+ return nn.functional.relu(self.group_norm(x))
diff --git a/ocl/neural_networks/wrappers.py b/ocl/neural_networks/wrappers.py
new file mode 100644
index 0000000..8f544ac
--- /dev/null
+++ b/ocl/neural_networks/wrappers.py
@@ -0,0 +1,33 @@
+"""Wrapper modules with allow the introduction or residuals or the combination of other modules."""
+from torch import nn
+
+
+class Residual(nn.Module):
+ def __init__(self, module: nn.Module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, inputs):
+ return inputs + self.module(inputs)
+
+
+class Sequential(nn.Module):
+ """Extended sequential module that supports multiple inputs and outputs to layers.
+
+ This allows a stack of layers where for example the first layer takes two inputs and only has
+ a single output or where a layer has multiple outputs and the downstream layer takes multiple
+ inputs.
+ """
+
+ def __init__(self, *layers):
+ super().__init__()
+ self.layers = nn.ModuleList(layers)
+
+ def forward(self, *inputs):
+ outputs = inputs
+ for layer in self.layers:
+ if isinstance(outputs, (tuple, list)):
+ outputs = layer(*outputs)
+ else:
+ outputs = layer(outputs)
+ return outputs
diff --git a/ocl/path_defaults.py b/ocl/path_defaults.py
new file mode 100644
index 0000000..9b7064d
--- /dev/null
+++ b/ocl/path_defaults.py
@@ -0,0 +1,21 @@
+"""Default paths for different types of inputs.
+
+These are only defined for convenience and can also be overwritten using the appropriate *_path
+constructor variables of RoutableMixin subclasses.
+"""
+MODEL = "model"
+INPUT = "input"
+VIDEO = f"{INPUT}.image"
+TEXT = f"{INPUT}.caption"
+BATCH_SIZE = f"{INPUT}.batch_size"
+BOX = f"{INPUT}.instance_bbox"
+MASK = f"{INPUT}.mask"
+ID = f"{INPUT}.instance_id"
+GLOBAL_STEP = "global_step"
+FEATURES = "feature_extractor"
+CONDITIONING = "conditioning"
+# TODO(hornmax): Currently decoders are nested in the task and accept PerceptualGroupingOutput as
+# input. In the future this will change and decoders should just be regular parts of the model.
+OBJECTS = "perceptual_grouping.objects"
+FEATURE_ATTRIBUTIONS = "perceptual_grouping.feature_attributions"
+OBJECT_DECODER = "object_decoder"
diff --git a/ocl/perceptual_grouping.py b/ocl/perceptual_grouping.py
new file mode 100644
index 0000000..13699b5
--- /dev/null
+++ b/ocl/perceptual_grouping.py
@@ -0,0 +1,479 @@
+"""Implementations of perceptual grouping algorithms."""
+import math
+from typing import Any, Dict, Optional
+
+import numpy
+import torch
+from sklearn import cluster
+from torch import nn
+from ocl import base, path_defaults
+from ocl.utils.routing import RoutableMixin
+
+
+class SlotAttention(nn.Module):
+ """Implementation of SlotAttention.
+
+ Based on the slot attention implementation of Phil Wang available at:
+ https://github.com/lucidrains/slot-attention
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ feature_dim: int,
+ kvq_dim: Optional[int] = None,
+ n_heads: int = 1,
+ iters: int = 3,
+ eps: float = 1e-8,
+ ff_mlp: Optional[nn.Module] = None,
+ use_projection_bias: bool = False,
+ use_implicit_differentiation: bool = False,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.n_heads = n_heads
+ self.iters = iters
+ self.eps = eps
+ self.use_implicit_differentiation = use_implicit_differentiation
+
+ if kvq_dim is None:
+ self.kvq_dim = dim
+ else:
+ self.kvq_dim = kvq_dim
+
+ if self.kvq_dim % self.n_heads != 0:
+ raise ValueError("Key, value, query dimensions must be divisible by number of heads.")
+ self.dims_per_head = self.kvq_dim // self.n_heads
+ self.scale = self.dims_per_head**-0.5
+
+ self.to_q = nn.Linear(dim, self.kvq_dim, bias=use_projection_bias)
+ self.to_k = nn.Linear(feature_dim, self.kvq_dim, bias=use_projection_bias)
+ self.to_v = nn.Linear(feature_dim, self.kvq_dim, bias=use_projection_bias)
+
+ self.gru = nn.GRUCell(self.kvq_dim, dim)
+
+ self.norm_input = nn.LayerNorm(feature_dim)
+ self.norm_slots = nn.LayerNorm(dim)
+ self.ff_mlp = ff_mlp
+
+ def step(self, slots, k, v, masks=None):
+ bs, n_slots, _ = slots.shape
+ slots_prev = slots
+
+ slots = self.norm_slots(slots)
+ q = self.to_q(slots).view(bs, n_slots, self.n_heads, self.dims_per_head)
+
+ dots = torch.einsum("bihd,bjhd->bihj", q, k) * self.scale
+ if masks is not None:
+ # Masked slots should not take part in the competition for features. By replacing their
+ # dot-products with -inf, their attention values will become zero within the softmax.
+ dots.masked_fill_(masks.to(torch.bool).view(bs, n_slots, 1, 1), float("-inf"))
+
+ attn = dots.flatten(1, 2).softmax(dim=1) # Take softmax over slots and heads
+ attn = attn.view(bs, n_slots, self.n_heads, -1)
+ attn_before_reweighting = attn
+ attn = attn + self.eps
+ attn = attn / attn.sum(dim=-1, keepdim=True)
+
+ updates = torch.einsum("bjhd,bihj->bihd", v, attn)
+
+ slots = self.gru(updates.reshape(-1, self.kvq_dim), slots_prev.reshape(-1, self.dim))
+
+ slots = slots.reshape(bs, -1, self.dim)
+
+ if self.ff_mlp:
+ slots = self.ff_mlp(slots)
+
+ return slots, attn_before_reweighting.mean(dim=2)
+
+ def iterate(self, slots, k, v, masks=None):
+ for _ in range(self.iters):
+ slots, attn = self.step(slots, k, v, masks)
+ return slots, attn
+
+ def forward(
+ self, inputs: torch.Tensor, conditioning: torch.Tensor, masks: Optional[torch.Tensor] = None
+ ):
+ b, n, d = inputs.shape
+ slots = conditioning
+
+ inputs = self.norm_input(inputs)
+ k = self.to_k(inputs).view(b, n, self.n_heads, self.dims_per_head)
+ v = self.to_v(inputs).view(b, n, self.n_heads, self.dims_per_head)
+
+ if self.use_implicit_differentiation:
+ slots, attn = self.iterate(slots, k, v, masks)
+ slots, attn = self.step(slots.detach(), k, v, masks)
+ else:
+ slots, attn = self.iterate(slots, k, v, masks)
+
+ return slots, attn
+
+
+class SlotAttentionGrouping(base.PerceptualGrouping, RoutableMixin):
+ """Implementation of SlotAttention for perceptual grouping.
+
+ Args:
+ feature_dim: Dimensionality of features to slot attention (after positional encoding).
+ object_dim: Dimensionality of slots.
+ kvq_dim: Dimensionality after projecting to keys, values, and queries. If `None`,
+ `object_dim` is used.
+ n_heads: Number of heads slot attention uses.
+ iters: Number of slot attention iterations.
+ eps: Epsilon in slot attention.
+ ff_mlp: Optional module applied slot-wise after GRU update.
+ positional_embedding: Optional module applied to the features before slot attention, adding
+ positional encoding.
+ use_projection_bias: Whether to use biases in key, value, query projections.
+ use_implicit_differentiation: Whether to use implicit differentiation trick. If true,
+ performs one more iteration of slot attention that is used for the gradient step after
+ `iters` iterations of slot attention without gradients. Faster and more memory efficient
+ than the standard version, but can not backpropagate gradients to the conditioning input.
+ input_dim: Dimensionality of features before positional encoding is applied. Specifying this
+ is optional but can be convenient to structure configurations.
+ """
+
+ def __init__(
+ self,
+ feature_dim: int,
+ object_dim: int,
+ kvq_dim: Optional[int] = None,
+ n_heads: int = 1,
+ iters: int = 3,
+ eps: float = 1e-8,
+ ff_mlp: Optional[nn.Module] = None,
+ positional_embedding: Optional[nn.Module] = None,
+ use_projection_bias: bool = False,
+ use_implicit_differentiation: bool = False,
+ use_empty_slot_for_masked_slots: bool = False,
+ input_dim: Optional[int] = None,
+ feature_path: Optional[str] = path_defaults.FEATURES,
+ conditioning_path: Optional[str] = path_defaults.CONDITIONING,
+ slot_mask_path: Optional[str] = None,
+ ):
+ base.PerceptualGrouping.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "extracted_features": feature_path,
+ "conditioning": conditioning_path,
+ "slot_masks": slot_mask_path,
+ },
+ )
+
+ self._object_dim = object_dim
+ self.slot_attention = SlotAttention(
+ dim=object_dim,
+ feature_dim=feature_dim,
+ kvq_dim=kvq_dim,
+ n_heads=n_heads,
+ iters=iters,
+ eps=eps,
+ ff_mlp=ff_mlp,
+ use_projection_bias=use_projection_bias,
+ use_implicit_differentiation=use_implicit_differentiation,
+ )
+
+ self.positional_embedding = positional_embedding
+
+ if use_empty_slot_for_masked_slots:
+ if slot_mask_path is None:
+ raise ValueError("Need `slot_mask_path` for `use_empty_slot_for_masked_slots`")
+ self.empty_slot = nn.Parameter(torch.randn(object_dim) * object_dim**-0.5)
+ else:
+ self.empty_slot = None
+
+ @property
+ def object_dim(self):
+ return self._object_dim
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ extracted_features: base.FeatureExtractorOutput,
+ conditioning: base.ConditioningOutput,
+ slot_masks: Optional[torch.Tensor] = None,
+ ):
+ if self.positional_embedding:
+ features = self.positional_embedding(
+ extracted_features.features, extracted_features.positions
+ )
+ else:
+ features = extracted_features.features
+
+ slots, attn = self.slot_attention(features, conditioning, slot_masks)
+
+ if slot_masks is not None and self.empty_slot is not None:
+ slots[slot_masks] = self.empty_slot.to(dtype=slots.dtype)
+
+ return base.PerceptualGroupingOutput(slots, feature_attributions=attn, is_empty=slot_masks)
+
+
+def sample_slot_lower_bound(A, lower_bound=1):
+    """Ensure that at least `lower_bound` slots are kept in each batch element.
+
+    A: [b, k] batch of slot keep-masks, where 0 means drop and 1 means keep.
+    Returns a mask B of the same shape that adds randomly sampled extra slots for batch
+    elements with fewer than `lower_bound` kept slots.
+    """
+    B = torch.zeros_like(A, device=A.device)
+ batch_slot_leftnum = (A != 0).sum(-1)
+ lesser_column_idx = torch.nonzero(batch_slot_leftnum < lower_bound).reshape(-1)
+ for j in lesser_column_idx:
+ left_slot_mask = A[j]
+ sample_slot_zero_idx = torch.nonzero(left_slot_mask==0).reshape(-1)
+ # Generate a random permutation of indices
+ sampled_indices = torch.randperm(sample_slot_zero_idx.size(0))[:lower_bound - batch_slot_leftnum[j]]
+ sampled_elements = sample_slot_zero_idx[sampled_indices]
+ B[j][sampled_elements] += 1
+ return B
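+
+# A toy illustration of the helper above (hypothetical values; not executed anywhere in
+# this module): rows that keep fewer than `lower_bound` slots receive randomly sampled
+# extra slots.
+#
+#     masks = torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
+#     masks = masks + sample_slot_lower_bound(masks, lower_bound=1)
+#     # First row is unchanged; the second row now keeps exactly one random slot.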
+
+class SlotAttentionGumbelV1(nn.Module):
+ """Implementation of SlotAttention with Gumbel Selection Module.
+
+ Based on the slot attention implementation of Phil Wang available at:
+ https://github.com/lucidrains/slot-attention
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ feature_dim: int,
+ kvq_dim: Optional[int] = None,
+ n_heads: int = 1,
+ iters: int = 3,
+ eps: float = 1e-8,
+ ff_mlp: Optional[nn.Module] = None,
+ use_projection_bias: bool = False,
+ use_implicit_differentiation: bool = False,
+        single_gumbel_score_network: Optional[nn.Module] = None,
+        low_bound: int = 0,
+        temporature_function=None,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.n_heads = n_heads
+ self.iters = iters
+ self.eps = eps
+ self.use_implicit_differentiation = use_implicit_differentiation
+
+ if kvq_dim is None:
+ self.kvq_dim = dim
+ else:
+ self.kvq_dim = kvq_dim
+
+ if self.kvq_dim % self.n_heads != 0:
+ raise ValueError("Key, value, query dimensions must be divisible by number of heads.")
+ self.dims_per_head = self.kvq_dim // self.n_heads
+ self.scale = self.dims_per_head**-0.5
+
+ self.to_q = nn.Linear(dim, self.kvq_dim, bias=use_projection_bias)
+ self.to_k = nn.Linear(feature_dim, self.kvq_dim, bias=use_projection_bias)
+ self.to_v = nn.Linear(feature_dim, self.kvq_dim, bias=use_projection_bias)
+
+ self.gru = nn.GRUCell(self.kvq_dim, dim)
+
+ self.norm_input = nn.LayerNorm(feature_dim)
+ self.norm_slots = nn.LayerNorm(dim)
+ self.ff_mlp = ff_mlp
+ self.single_gumbel_score_network = single_gumbel_score_network
+ self.low_bound = low_bound
+ self.temporature_function = temporature_function
+
+ def step(self, slots, k, v, masks=None):
+ bs, n_slots, _ = slots.shape
+ slots_prev = slots
+
+ slots = self.norm_slots(slots)
+ q = self.to_q(slots).view(bs, n_slots, self.n_heads, self.dims_per_head)
+
+ dots = torch.einsum("bihd,bjhd->bihj", q, k) * self.scale
+ if masks is not None:
+ # Masked slots should not take part in the competition for features. By replacing their
+ # dot-products with -inf, their attention values will become zero within the softmax.
+ dots.masked_fill_(masks.to(torch.bool).view(bs, n_slots, 1, 1), float("-inf"))
+
+ attn = dots.flatten(1, 2).softmax(dim=1) # Take softmax over slots and heads
+ attn = attn.view(bs, n_slots, self.n_heads, -1)
+ attn_before_reweighting = attn
+ attn = attn + self.eps
+ attn = attn / attn.sum(dim=-1, keepdim=True)
+
+ updates = torch.einsum("bjhd,bihj->bihd", v, attn)
+
+ slots = self.gru(updates.reshape(-1, self.kvq_dim), slots_prev.reshape(-1, self.dim))
+
+ slots = slots.reshape(bs, -1, self.dim)
+
+ if self.ff_mlp:
+ slots = self.ff_mlp(slots)
+
+ return slots, attn_before_reweighting.mean(dim=2)
+
+ def iterate(self, slots, k, v, masks=None):
+ for _ in range(self.iters):
+ slots, attn = self.step(slots, k, v, masks)
+ return slots, attn
+
+ def forward(
+ self, inputs: torch.Tensor, conditioning: torch.Tensor, masks: Optional[torch.Tensor] = None,
+ global_step = None
+ ):
+ b, n, d = inputs.shape
+ slots = conditioning
+
+ inputs = self.norm_input(inputs)
+ k = self.to_k(inputs).view(b, n, self.n_heads, self.dims_per_head)
+ v = self.to_v(inputs).view(b, n, self.n_heads, self.dims_per_head)
+
+ if self.use_implicit_differentiation:
+ slots, attn = self.iterate(slots, k, v, masks)
+ slots, attn = self.step(slots.detach(), k, v, masks)
+ else:
+ slots, attn = self.iterate(slots, k, v, masks)
+
+        # Gumbel slot selection: score each slot and draw a hard keep/drop decision.
+        _, k, _ = conditioning.shape
+        prev_decision = torch.ones(b, k, dtype=slots.dtype, device=slots.device)  # [b, k]
+
+        slots_keep_prob = self.single_gumbel_score_network(slots)  # [b, k, 2]
+        if global_step is None:
+            tau = 1
+        else:
+            tau = self.temporature_function(global_step)
+        current_keep_decision = F.gumbel_softmax(slots_keep_prob, hard=True, tau=tau)[..., 1]
+        if self.low_bound > 0:
+            current_keep_decision = current_keep_decision + sample_slot_lower_bound(
+                current_keep_decision, self.low_bound
+            )
+        hard_keep_decision = current_keep_decision * prev_decision  # [b, k]
+        slots_keep_prob = F.softmax(slots_keep_prob, dim=-1)[..., 1]
+
+ return slots, attn, slots_keep_prob, hard_keep_decision
+
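+
+# The keep/drop decision used in `SlotAttentionGumbelV1.forward`, in isolation (a minimal
+# sketch with hypothetical shapes; the score network is any module producing two logits
+# per slot):
+#
+#     logits = torch.randn(2, 5, 2)                         # [batch, slots, (drop, keep)]
+#     hard = F.gumbel_softmax(logits, tau=1.0, hard=True)   # one-hot, straight-through
+#     hard_keep_decision = hard[..., 1]                     # 1 keeps a slot, 0 drops it
+#     slots_keep_prob = F.softmax(logits, dim=-1)[..., 1]   # soft keep probability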
+
+class SlotAttentionGroupingGumbelV1(base.PerceptualGrouping, RoutableMixin):
+    """SlotAttention with Gumbel-based adaptive slot selection for perceptual grouping.
+
+ Args:
+ feature_dim: Dimensionality of features to slot attention (after positional encoding).
+ object_dim: Dimensionality of slots.
+ kvq_dim: Dimensionality after projecting to keys, values, and queries. If `None`,
+ `object_dim` is used.
+ n_heads: Number of heads slot attention uses.
+ iters: Number of slot attention iterations.
+ eps: Epsilon in slot attention.
+ ff_mlp: Optional module applied slot-wise after GRU update.
+ positional_embedding: Optional module applied to the features before slot attention, adding
+ positional encoding.
+ use_projection_bias: Whether to use biases in key, value, query projections.
+ use_implicit_differentiation: Whether to use implicit differentiation trick. If true,
+ performs one more iteration of slot attention that is used for the gradient step after
+ `iters` iterations of slot attention without gradients. Faster and more memory efficient
+ than the standard version, but can not backpropagate gradients to the conditioning input.
+ input_dim: Dimensionality of features before positional encoding is applied. Specifying this
+            is optional but can be convenient to structure configurations.
+        single_gumbel_score_network: Module mapping each slot to a pair of keep/drop logits
+            used for the Gumbel slot selection.
+        low_bound: Minimum number of slots to keep per instance (0 disables the lower bound).
+        temporature_function: Callable mapping the global step to the Gumbel-softmax
+            temperature.
+    """
+
+ def __init__(
+ self,
+ feature_dim: int,
+ object_dim: int,
+ kvq_dim: Optional[int] = None,
+ n_heads: int = 1,
+ iters: int = 3,
+ eps: float = 1e-8,
+ ff_mlp: Optional[nn.Module] = None,
+ positional_embedding: Optional[nn.Module] = None,
+ use_projection_bias: bool = False,
+ use_implicit_differentiation: bool = False,
+ use_empty_slot_for_masked_slots: bool = False,
+ input_dim: Optional[int] = None,
+ feature_path: Optional[str] = path_defaults.FEATURES,
+ conditioning_path: Optional[str] = path_defaults.CONDITIONING,
+ slot_mask_path: Optional[str] = None,
+ single_gumbel_score_network: Optional[nn.Module] = None,
+        low_bound: int = 0,
+        temporature_function=None,
+ ):
+ base.PerceptualGrouping.__init__(self)
+ RoutableMixin.__init__(
+ self,
+ {
+ "extracted_features": feature_path,
+ "conditioning": conditioning_path,
+ "slot_masks": slot_mask_path,
+ "global_step": path_defaults.GLOBAL_STEP
+ },
+ )
+
+ self._object_dim = object_dim
+ self.slot_attention = SlotAttentionGumbelV1(
+ dim=object_dim,
+ feature_dim=feature_dim,
+ kvq_dim=kvq_dim,
+ n_heads=n_heads,
+ iters=iters,
+ eps=eps,
+ ff_mlp=ff_mlp,
+ use_projection_bias=use_projection_bias,
+ use_implicit_differentiation=use_implicit_differentiation,
+            single_gumbel_score_network=single_gumbel_score_network,
+            low_bound=low_bound,
+            temporature_function=temporature_function,
+ )
+
+ self.positional_embedding = positional_embedding
+
+ if use_empty_slot_for_masked_slots:
+ if slot_mask_path is None:
+ raise ValueError("Need `slot_mask_path` for `use_empty_slot_for_masked_slots`")
+ self.empty_slot = nn.Parameter(torch.randn(object_dim) * object_dim**-0.5)
+ else:
+ self.empty_slot = None
+
+        if temporature_function is None:
+            # Default to a constant unit temperature so the Gumbel-softmax always receives a
+            # valid schedule, and propagate the default to the inner slot attention module.
+            temporature_function = lambda step: 1
+            self.slot_attention.temporature_function = temporature_function
+
+        self.temporature_function = temporature_function
+
+ @property
+ def object_dim(self):
+ return self._object_dim
+
+ @RoutableMixin.route
+ def forward(
+ self,
+ extracted_features: base.FeatureExtractorOutput,
+ conditioning: base.ConditioningOutput,
+ slot_masks: Optional[torch.Tensor] = None,
+        global_step=None,
+ ):
+ if self.positional_embedding:
+ features = self.positional_embedding(
+ extracted_features.features, extracted_features.positions
+ )
+ else:
+ features = extracted_features.features
+
+        slots, attn, slots_keep_prob, hard_keep_decision = self.slot_attention(
+            features, conditioning, slot_masks, global_step=global_step
+        )
+
+ if slot_masks is not None and self.empty_slot is not None:
+ slots[slot_masks] = self.empty_slot.to(dtype=slots.dtype)
+
+
+ return {
+ "objects": slots,
+ "is_empty": slot_masks,
+            "feature_attributions": attn,
+            "slots_keep_prob": slots_keep_prob,
+            "hard_keep_decision": hard_keep_decision,
+        }
\ No newline at end of file
diff --git a/ocl/plugins.py b/ocl/plugins.py
new file mode 100644
index 0000000..163a554
--- /dev/null
+++ b/ocl/plugins.py
@@ -0,0 +1,1839 @@
+import functools
+import logging
+import math
+import os
+import random
+from collections import defaultdict
+from io import BytesIO
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
+
+import decord
+import numpy as np
+import torch
+import webdataset
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+
+from ocl import hooks
+from ocl.utils.trees import get_tree_element
+
+decord.bridge.set_bridge("torch")
+LOGGER = logging.getLogger(__name__)
+
+
+class Plugin:
+ """A plugin which defines a set of hooks to be called by the code."""
+
+
+class Optimization(Plugin):
+    """Optimize (a subset of) the parameters using an optimizer and an LR scheduler."""
+
+ def __init__(
+ self, optimizer, lr_scheduler=None, parameter_groups: Optional[List[Dict[str, Any]]] = None
+ ):
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ self.parameter_group_specs = parameter_groups
+ if self.parameter_group_specs:
+ for idx, param_group_spec in enumerate(self.parameter_group_specs):
+ if "params" not in param_group_spec:
+ raise ValueError(f'Parameter group {idx + 1} does not contain key "params"')
+ param_spec = param_group_spec["params"]
+ if isinstance(param_spec, str):
+ param_group_spec["params"] = [param_spec]
+ elif isinstance(param_spec, Iterable):
+ param_group_spec["params"] = list(param_spec)
+ else:
+ raise ValueError(
+ f'"params" for parameter group {idx + 1} is not of type str or iterable'
+ )
+
+ if "predicate" in param_group_spec:
+ if not callable(param_group_spec["predicate"]):
+ raise ValueError(
+ f'"predicate" for parameter group {idx + 1} is not a callable'
+ )
+
+ def _get_parameter_groups(self, model):
+ """Build parameter groups from specification."""
+ parameter_groups = []
+ for param_group_spec in self.parameter_group_specs:
+ param_spec = param_group_spec["params"]
+ # Default predicate includes all parameters
+ predicate = param_group_spec.get("predicate", lambda name, param: True)
+
+ parameters = []
+ for parameter_path in param_spec:
+ root = model
+ for child in parameter_path.split("."):
+ root = getattr(root, child)
+ parameters.extend(
+ param for name, param in root.named_parameters() if predicate(name, param)
+ )
+
+ param_group = {
+ k: v for k, v in param_group_spec.items() if k not in ("params", "predicate")
+ }
+ param_group["params"] = parameters
+ parameter_groups.append(param_group)
+
+ return parameter_groups
+
+ @hooks.hook_implementation
+ def configure_optimizers(self, model):
+ if self.parameter_group_specs:
+ params_or_param_groups = self._get_parameter_groups(model)
+ else:
+ params_or_param_groups = model.parameters()
+
+ optimizer = self.optimizer(params_or_param_groups)
+ output = {"optimizer": optimizer}
+ if self.lr_scheduler:
+ output.update(self.lr_scheduler(optimizer))
+ return output
+
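+
+# A minimal configuration sketch for the `Optimization` plugin above (module paths and
+# learning rates are illustrative, not taken from any config in this repository):
+#
+#     Optimization(
+#         optimizer=functools.partial(torch.optim.Adam, lr=2e-4),
+#         parameter_groups=[
+#             {"params": "feature_extractor", "lr": 1e-5},
+#             {"params": ["perceptual_grouping", "object_decoder"]},
+#         ],
+#     )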
+
+class FreezeParameters(Plugin):
+ def __init__(self, parameter_groups: List[Dict[str, Any]]):
+ self.parameter_group_specs = parameter_groups
+ for idx, param_group_spec in enumerate(self.parameter_group_specs):
+ if "params" not in param_group_spec:
+ raise ValueError(f'Parameter group {idx + 1} does not contain key "params"')
+ param_spec = param_group_spec["params"]
+ if isinstance(param_spec, str):
+ param_group_spec["params"] = [param_spec]
+ elif isinstance(param_spec, Iterable):
+ param_group_spec["params"] = list(param_spec)
+ else:
+ raise ValueError(
+ f'"params" for parameter group {idx + 1} is not of type str or iterable'
+ )
+
+ if "predicate" in param_group_spec:
+ if not callable(param_group_spec["predicate"]):
+ raise ValueError(f'"predicate" for parameter group {idx + 1} is not a callable')
+
+ def _get_parameters_to_freeze(self, model):
+ """Build parameter groups from specification."""
+ parameters_to_freeze = []
+ for param_group_spec in self.parameter_group_specs:
+ for current_params in param_group_spec["params"]:
+ param_path = current_params.split(".")
+ # Default predicate includes all parameters
+ predicate = param_group_spec.get("predicate", lambda name, param: True)
+ param = get_tree_element(model, param_path)
+ if isinstance(param, torch.nn.Module):
+ parameters_to_freeze.extend(
+ param for name, param in param.named_parameters() if predicate(name, param)
+ )
+ elif isinstance(param, torch.nn.Parameter):
+ parameters_to_freeze.append(param)
+ else:
+ raise ValueError(
+                        f"Object at path {'.'.join(param_path)} is neither nn.Module nor nn.Parameter"
+ )
+ return parameters_to_freeze
+
+ @hooks.hook_implementation
+ def on_train_start(self, model):
+ parameters_to_freeze = self._get_parameters_to_freeze(model)
+ for param in parameters_to_freeze:
+ param.requires_grad_(False)
+
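+
+# Sketch for the `FreezeParameters` plugin above (the module path is illustrative): freeze
+# a pre-trained backbone while keeping the rest of the model trainable.
+#
+#     FreezeParameters(parameter_groups=[{"params": "feature_extractor"}])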
+
+class RestoreParameterSubset(Plugin):
+    """Restore a subset of parameters using a checkpoint from a different model."""
+
+ def __init__(self, checkpoint_file: str, target_path: str, source_path: Optional[str] = None):
+ self.checkpoint_file = checkpoint_file
+ self.target_path = target_path
+ self.source_path = source_path if source_path else self.target_path
+
+ @hooks.hook_implementation
+ def on_train_start(self, model):
+ if model.global_step != 0:
+ # Don't restore when we are resuming training.
+ rank_zero_warn("Not restoring parameter subset as training is being resumed")
+ return
+ device = model.device
+ # Get parameters from state dict, load to cpu first to avoid memory issues.
+ state_dict = torch.load(self.checkpoint_file, map_location="cpu")["state_dict"]
+ # Add offset of 1 to remove potential dot.
+ offset_keys = len(self.source_path) + 1
+ state_dict = {
+ key[offset_keys:]: value
+ for key, value in state_dict.items()
+ if key.startswith(self.source_path)
+ }
+
+ # Get module from model
+ model_component: torch.nn.Module = get_tree_element(model, self.target_path.split("."))
+ result = model_component.load_state_dict(state_dict, strict=False)
+ model_component.to(device=device)
+ if len(result.missing_keys):
+ rank_zero_warn(
+ f"Mismatch between state dict and model. Missing keys: {result.missing_keys}"
+ )
+ if len(result.unexpected_keys):
+ rank_zero_warn(
+                f"Mismatch between state dict and model. Unexpected keys: {result.unexpected_keys}"
+ )
+
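+
+# Sketch for `RestoreParameterSubset` (checkpoint file and module path are illustrative):
+# initialise only part of the model from a checkpoint of a different run.
+#
+#     RestoreParameterSubset(
+#         checkpoint_file="pretrained.ckpt",
+#         target_path="feature_extractor",
+#     )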
+
+def transform_with_duplicate(elements: dict, *, transform, element_key: str, duplicate_key: str):
+ """Utility function to fix issues with pickling."""
+ element = transform(elements[element_key])
+ elements[element_key] = element
+ elements[duplicate_key] = element
+ return elements
+
+
+class SingleElementPreprocessing(Plugin):
+ """Preprocessing of a single element in the input data.
+
+ This is useful to build preprocessing pipelines based on existing element transformations such
+ as those provided by torchvision. The element can optionally be duplicated and stored under a
+ different key after the transformation by specifying `duplicate_key`. This is useful to further
+ preprocess this element in different ways afterwards.
+ """
+
+ def __init__(
+ self,
+ training_transform: Callable,
+ evaluation_transform: Callable,
+ element_key: str = "image",
+ duplicate_key: Optional[str] = None,
+ ):
+ self._training_transform = training_transform
+ self._evaluation_transform = evaluation_transform
+ self.element_key = element_key
+ self.duplicate_key = duplicate_key
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return (self.element_key,)
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ if self._training_transform:
+
+ if self.duplicate_key is None:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map_dict(**{self.element_key: self._training_transform})
+
+ else:
+
+ def transform(pipeline: webdataset.Processor):
+ transform_func = functools.partial(
+ transform_with_duplicate,
+ transform=self._training_transform,
+ element_key=self.element_key,
+ duplicate_key=self.duplicate_key,
+ )
+ return pipeline.map(transform_func)
+
+ return transform
+ else:
+ return None
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return (self.element_key,)
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ if self._evaluation_transform:
+
+ if self.duplicate_key is None:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map_dict(**{self.element_key: self._evaluation_transform})
+
+ else:
+
+ def transform(pipeline: webdataset.Processor):
+ transform_func = functools.partial(
+ transform_with_duplicate,
+ transform=self._evaluation_transform,
+ element_key=self.element_key,
+ duplicate_key=self.duplicate_key,
+ )
+
+ return pipeline.map(transform_func)
+
+ return transform
+ else:
+ return None
+
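+
+# A usage sketch for `SingleElementPreprocessing` (the transforms are illustrative): resize
+# the "image" element and also store the transformed tensor under a second key so it can be
+# processed differently further downstream.
+#
+#     from torchvision import transforms
+#
+#     SingleElementPreprocessing(
+#         training_transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(128)]),
+#         evaluation_transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(128)]),
+#         element_key="image",
+#         duplicate_key="image_resized",
+#     )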
+
+class MultiElementPreprocessing(Plugin):
+ """Preprocessing of multiple elements in the input data.
+
+    This is useful for building preprocessing pipelines based on existing element transformations
+    such as those provided by torchvision.
+ """
+
+ def __init__(
+ self,
+ training_transforms: Optional[Dict[str, Any]] = None,
+ evaluation_transforms: Optional[Dict[str, Any]] = None,
+ ):
+ if training_transforms is None:
+ training_transforms = {}
+ self.training_keys = tuple(training_transforms)
+ self._training_transforms = {
+ key: transf for key, transf in training_transforms.items() if transf is not None
+ }
+
+ if evaluation_transforms is None:
+ evaluation_transforms = {}
+ self.evaluation_keys = tuple(evaluation_transforms)
+ self._evaluation_transforms = {
+ key: transf for key, transf in evaluation_transforms.items() if transf is not None
+ }
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self.training_keys
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ if self._training_transforms:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map_dict(**self._training_transforms)
+
+ return transform
+ else:
+ return None
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self.evaluation_keys
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ if self._evaluation_transforms:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map_dict(**self._evaluation_transforms)
+
+ return transform
+ else:
+ return None
+
+
+class DataPreprocessing(Plugin):
+ """Arbitrary preprocessing of input data.
+
+ The transform takes in a dictionary of elements and should return a dictionary of elements.
+    The plugin must specify the elements that should be included in the dictionary using the
+    `training_fields` and `evaluation_fields` arguments.
+ """
+
+ def __init__(
+ self,
+ training_transform: Optional[Callable] = None,
+ evaluation_transform: Optional[Callable] = None,
+ training_fields: Optional[Sequence[str]] = None,
+ evaluation_fields: Optional[Sequence[str]] = None,
+ ):
+ if training_transform is not None and training_fields is None:
+ raise ValueError(
+ "If passing `training_transform`, `training_fields` must also be specified."
+ )
+ if evaluation_transform is not None and evaluation_fields is None:
+ raise ValueError(
+ "If passing `evaluation_transform`, `evaluation_fields` must also be specified."
+ )
+
+ self._training_transform = training_transform
+ self._evaluation_transform = evaluation_transform
+ self._training_fields = tuple(training_fields) if training_fields else tuple()
+ self._evaluation_fields = tuple(evaluation_fields) if evaluation_fields else tuple()
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ if self._training_transform:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map(self._training_transform)
+
+ return transform
+ else:
+ return None
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ if self._evaluation_transform:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map(self._evaluation_transform)
+
+ return transform
+ else:
+ return None
+
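+
+# Sketch of an arbitrary sample-level transform for `DataPreprocessing` (field names are
+# illustrative): derive an extra field from an existing one.
+#
+#     def add_num_instances(sample: dict) -> dict:
+#         sample["num_instances"] = int(sample["mask"].max()) + 1
+#         return sample
+#
+#     DataPreprocessing(training_transform=add_num_instances, training_fields=["mask"])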
+
+class BatchDataPreprocessing(Plugin):
+ """Arbitrary preprocessing of input data batch.
+
+ The transform takes in a dictionary of elements and should return a dictionary of elements.
+    The plugin must specify the elements that should be included in the dictionary using the
+    `training_fields` and `evaluation_fields` arguments.
+ """
+
+ def __init__(
+ self,
+ training_transform: Optional[Callable] = None,
+ evaluation_transform: Optional[Callable] = None,
+ training_fields: Optional[Sequence[str]] = None,
+ evaluation_fields: Optional[Sequence[str]] = None,
+ ):
+ if training_transform is not None and training_fields is None:
+ raise ValueError(
+ "If passing `training_transform`, `training_fields` must also be specified."
+ )
+ if evaluation_transform is not None and evaluation_fields is None:
+ raise ValueError(
+ "If passing `evaluation_transform`, `evaluation_fields` must also be specified."
+ )
+
+ self._training_transform = training_transform
+ self._evaluation_transform = evaluation_transform
+ self._training_fields = tuple(training_fields) if training_fields else tuple()
+ self._evaluation_fields = tuple(evaluation_fields) if evaluation_fields else tuple()
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ @hooks.hook_implementation
+ def training_batch_transform(self):
+ if self._training_transform:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map(self._training_transform)
+
+ return transform
+ else:
+ return None
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_batch_transform(self):
+ if self._evaluation_transform:
+
+ def transform(pipeline: webdataset.Processor):
+ return pipeline.map(self._evaluation_transform)
+
+ return transform
+ else:
+ return None
+
+
+def _transform_elements(inputs, transforms):
+ for key, transform in transforms.items():
+ inputs[key] = transform(inputs[key])
+ return inputs
+
+
+class MultiElementBatchPreprocessing(BatchDataPreprocessing):
+ """Preprocessing of multiple elements in the batched input data.
+
+    This is useful for building preprocessing pipelines based on existing element transformations
+    such as those provided by torchvision.
+ """
+
+ def __init__(
+ self,
+ training_transforms: Optional[Dict[str, Any]] = None,
+ evaluation_transforms: Optional[Dict[str, Any]] = None,
+ ):
+ if training_transforms is None:
+ training_transform = None
+ training_fields = None
+ else:
+ training_fields = tuple(training_transforms)
+ training_transforms = {
+ key: transf for key, transf in training_transforms.items() if transf is not None
+ }
+ training_transform = functools.partial(
+ _transform_elements, transforms=training_transforms
+ )
+
+ if evaluation_transforms is None:
+ evaluation_transform = None
+ evaluation_fields = None
+ else:
+ evaluation_fields = tuple(evaluation_transforms)
+ evaluation_transforms = {
+ key: transf for key, transf in evaluation_transforms.items() if transf is not None
+ }
+ evaluation_transform = functools.partial(
+ _transform_elements, transforms=evaluation_transforms
+ )
+ super().__init__(
+ training_transform=training_transform,
+ evaluation_transform=evaluation_transform,
+ training_fields=training_fields,
+ evaluation_fields=evaluation_fields,
+ )
+
+
+def _transform_single_element(inputs, field, transform, duplicate_key):
+ if duplicate_key:
+ inputs[duplicate_key] = inputs[field]
+ inputs[field] = transform(inputs[field])
+ return inputs
+
+
+class SingleElementBatchPreprocessing(BatchDataPreprocessing):
+ """Preprocessing of a single element in the batched input data.
+
+ This is useful to build preprocessing pipelines based on existing element transformations such
+    as those provided by torchvision. The original (untransformed) element can optionally be
+    duplicated and stored under a different key by specifying `duplicate_key`. This is useful to
+    further preprocess this element in different ways afterwards.
+ """
+
+ def __init__(
+ self,
+ training_transform: Optional[Callable],
+ evaluation_transform: Optional[Callable],
+ element_key: str = "image",
+ duplicate_key: Optional[str] = None,
+ ):
+ if training_transform is None:
+ training_fields = None
+ else:
+ training_fields = [element_key]
+ training_transform = functools.partial(
+ _transform_single_element,
+ field=element_key,
+ transform=training_transform,
+ duplicate_key=duplicate_key,
+ )
+
+ if evaluation_transform is None:
+ evaluation_fields = None
+ else:
+ evaluation_fields = [element_key]
+ evaluation_transform = functools.partial(
+ _transform_single_element,
+ field=element_key,
+ transform=evaluation_transform,
+ duplicate_key=duplicate_key,
+ )
+
+ super().__init__(
+ training_transform=training_transform,
+ evaluation_transform=evaluation_transform,
+ training_fields=training_fields,
+ evaluation_fields=evaluation_fields,
+ )
+
+
+class SubsetDataset(Plugin):
+ """Create a subset of a dataset by discarding samples."""
+
+ def __init__(
+ self, predicate, fields: Sequence[str], subset_train: bool = True, subset_eval: bool = True
+ ):
+ """Plugin to create a subset of a dataset by discarding samples.
+
+ Args:
+ predicate: Function which determines if elements should be kept (return value is True)
+ or discarded (return value is False). The function is only provided with the fields
+ specified in the `fields` parameter.
+ fields (Sequence[str]): The fields from the input which should be passed on to the
+ predicate for evaluation.
+ subset_train: Subset training data.
+ subset_eval: Subset evaluation data.
+ """
+ self.predicate = predicate
+ self.fields = tuple(fields)
+ self.subset_train = subset_train
+ self.subset_eval = subset_eval
+
+ def _get_transform_function(self):
+ def wrapped_predicate(d: dict):
+ return self.predicate(*(d[field] for field in self.fields))
+
+ def select(pipeline: webdataset.Processor):
+ return pipeline.select(wrapped_predicate)
+
+ return select
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ if self.subset_train:
+ return self.fields
+ else:
+ return tuple()
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ if self.subset_train:
+ return self._get_transform_function()
+ else:
+ return None
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ if self.subset_eval:
+ return self.fields
+ else:
+ return tuple()
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ if self.subset_eval:
+ return self._get_transform_function()
+ else:
+ return None
+
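+
+# Sketch for `SubsetDataset` (field name and criterion are illustrative): keep only samples
+# whose instance mask contains at most ten objects. The predicate receives the requested
+# fields as positional arguments.
+#
+#     SubsetDataset(predicate=lambda mask: mask.max() <= 10, fields=["mask"])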
+
+class SampleFramesFromVideo(Plugin):
+ def __init__(
+ self,
+ n_frames_per_video: int,
+ training_fields: Sequence[str],
+ evaluation_fields: Sequence[str],
+ dim: int = 0,
+ seed: int = 39480234,
+ per_epoch: bool = False,
+ shuffle_buffer_size: int = 1000,
+ n_eval_frames_per_video: Optional[int] = None,
+ ):
+ """Sample frames from input tensors.
+
+ Args:
+ n_frames_per_video: Number of frames per video to sample. -1 indicates that all frames
+ should be sampled.
+ training_fields: The fields that should be considered video data and thus sliced
+ according to the frame sampling during training.
+ evaluation_fields: The fields that should be considered video data and thus sliced
+ according to the frame sampling during evaluation.
+ dim: The dimension along which to slice the tensors.
+            seed: Random number generator seed for deterministic sampling during evaluation.
+            per_epoch: Stratify frame sampling over epochs; this ensures that after
+                n_frames / n_frames_per_video epochs all frames have been seen at least once.
+ In the case of uneven division, some frames will be seen more than once.
+ shuffle_buffer_size: Size of shuffle buffer used during training. An additional
+ shuffling step ensures each batch contains a diverse set of images and not only
+ images from the same video.
+ n_eval_frames_per_video: Number of frames per video to sample on the evaluation splits.
+ """
+ self.n_frames_per_video = n_frames_per_video
+ self._training_fields = tuple(training_fields)
+ self._evaluation_fields = tuple(evaluation_fields)
+ self.dim = dim
+ self.seed = seed
+ self.per_epoch = per_epoch
+ self.shuffle_buffer_size = shuffle_buffer_size
+ if n_eval_frames_per_video is not None:
+ self.n_eval_frames_per_video = n_eval_frames_per_video
+ else:
+ self.n_eval_frames_per_video = n_frames_per_video
+
+ def slice_data(self, data, index: int):
+ """Small utility method to slice a numpy array along a specified axis."""
+ n_dims_before = self.dim
+ n_dims_after = data.ndim - 1 - self.dim
+ slices = (slice(None),) * n_dims_before + (index,) + (slice(None),) * n_dims_after
+ return data[slices]
+
+ def sample_frames_using_key(self, data, fields, seed, n_frames_per_video):
+ """Sample frames deterministically from generator of videos using the __key__ field."""
+        for sample in data:
+            # Initialize random number generator dependent on instance key. This should make the
+            # sampling process deterministic, which is useful when sampling frames for the
+            # validation/test data.
+            key = sample["__key__"]
+
+ n_frames = sample[fields[0]].shape[self.dim]
+ frames_per_video = self.n_frames_per_video if self.n_frames_per_video != -1 else n_frames
+
+ if self.per_epoch and self.n_frames_per_video != -1:
+ n_different_epochs_per_seed = int(math.ceil(n_frames / frames_per_video))
+ try:
+ epoch = int(os.environ["WDS_EPOCH"])
+ except KeyError:
+ raise RuntimeError(
+                        "Using SampleFramesFromVideo with per_epoch=True "
+ "requires `WDS_EPOCH` to be set."
+ )
+ # Only update the seed after n_frames / n_frames_per_video epochs.
+ # This ensures that we get the same random order of frames until
+ # we have sampled all of them.
+ rand = np.random.RandomState(
+ int(key) + seed + (epoch // n_different_epochs_per_seed)
+ )
+ indices = rand.permutation(n_frames)
+ selected_frames = indices[
+ epoch * self.n_frames_per_video : (epoch + 1) * self.n_frames_per_video
+ ].tolist()
+ if len(selected_frames) < self.n_frames_per_video:
+ # Input cannot be evenly split, take some frames from the first batch of frames.
+ n_missing = self.n_frames_per_video - len(selected_frames)
+ selected_frames.extend(indices[0:n_missing].tolist())
+ else:
+ rand = random.Random(int(key) + seed)
+ selected_frames = rand.sample(range(n_frames), k=frames_per_video)
+
+ for frame in selected_frames:
+ # Slice the fields according to the frame, we use copy in order to allow freeing of
+ # the original tensor.
+ sliced_fields = {
+ field: self.slice_data(sample[field], frame).copy() for field in fields
+ }
+ # Leave all fields besides the sliced ones as before, augment the __key__ field to
+ # include the frame number.
+ sliced_fields["__key__"] = f"{key}_{frame}"
+ yield {**sample, **sliced_fields}
+
+ # Delete fields to be sure we remove all references.
+ for field in fields:
+ del sample[field]
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._training_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.sample_frames_using_key,
+ fields=self._training_fields,
+ seed=self.seed,
+ n_frames_per_video=self.n_frames_per_video,
+ )
+ ).shuffle(self.shuffle_buffer_size)
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._evaluation_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.sample_frames_using_key,
+ fields=self._evaluation_fields,
+ seed=self.seed + 1,
+ n_frames_per_video=self.n_eval_frames_per_video,
+ )
+ )
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
+
+
+class SplitConsecutiveFrames(Plugin):
+ def __init__(
+ self,
+ n_consecutive_frames: int,
+ training_fields: Sequence[str],
+ evaluation_fields: Sequence[str],
+ dim: int = 0,
+ shuffle_buffer_size: int = 1000,
+ drop_last: bool = True,
+ ):
+ self.n_consecutive_frames = n_consecutive_frames
+ self._training_fields = tuple(training_fields)
+ self._evaluation_fields = tuple(evaluation_fields)
+ self.dim = dim
+ self.shuffle_buffer_size = shuffle_buffer_size
+ self.drop_last = drop_last
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ def split_to_consecutive_frames(self, data, fields):
+        """Split each video from the generator into chunks of consecutive frames."""
+ for sample in data:
+ key = sample["__key__"]
+ n_frames = sample[fields[0]].shape[self.dim]
+
+ splitted_fields = [
+ np.array_split(
+ sample[field],
+ range(self.n_consecutive_frames, n_frames, self.n_consecutive_frames),
+ axis=self.dim,
+ )
+ for field in fields
+ ]
+
+ for i, slices in enumerate(zip(*splitted_fields)):
+ if self.drop_last and slices[0].shape[self.dim] < self.n_consecutive_frames:
+ # Last slice of not equally divisible input, discard.
+ continue
+
+ sliced_fields = dict(zip(fields, slices))
+ sliced_fields["__key__"] = f"{key}_{i}"
+ yield {**sample, **sliced_fields}
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._training_fields) > 0:
+ return pipeline.then(
+ functools.partial(self.split_to_consecutive_frames, fields=self._training_fields)
+ ).shuffle(self.shuffle_buffer_size)
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._evaluation_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.split_to_consecutive_frames,
+ fields=self._evaluation_fields,
+ )
+ )
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
+
+
+class VideoDecoder(Plugin):
+    """Video decoder based on decord."""
+
+ def __init__(
+ self,
+ input_fields: Union[List[str], str],
+ stride: int = 1,
+ split_extension: bool = True,
+ video_reader_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ """Video decoder based on decord.
+
+ It will decode the whole video into a single tensor and can be used with other downstream
+ processing plugins.
+
+ Args:
+            input_fields (Union[List[str], str]): The field(s) of the input dictionary
+                containing the video bytes.
+ stride (int): Downsample frames by using striding. Default: 1
+ split_extension (bool): Split the extension off the field name.
+ video_reader_kwargs (Dict[str, Any]): Arguments to decord.VideoReader.
+ """
+ self.input_fields = list(input_fields) if isinstance(input_fields, list) else [input_fields]
+ self.stride = stride
+ self.split_extension = split_extension
+ self.video_reader_kwargs = video_reader_kwargs if video_reader_kwargs else {}
+
+ def _chunk_iterator(
+ self, vrs: Mapping[str, decord.VideoReader], key: str, inputs: Dict[str, Any]
+ ) -> Tuple[str, torch.Tensor]:
+ """Iterate over chunks of the video.
+
+        For the video decoder we simply return a single chunk containing the whole video;
+        subclasses might override this method, though.
+
+ Returns:
+ str: Derived key which combines chunk and video key.
+ torch.Tensor: Chunk of video data.
+            Dict: Additional information, for example which frames were selected. This might be of
+ relevance when different modalities need to be sliced in a similar fashion as the
+ video input.
+ """
+ # Get whole video.
+ indices = list(range(0, len(next(iter(vrs.values()))), self.stride))
+ videos = {output_name: vr.get_batch(indices) for output_name, vr in vrs.items()}
+ yield key, {**videos, "decoded_indices": indices}
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self.input_fields
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self.input_fields
+
+ def video_decoding(self, input_generator, chunking):
+ for input_data in input_generator:
+ key = input_data["__key__"]
+ vrs = {}
+ for input_field in self.input_fields:
+ video_bytes: bytes = input_data[input_field]
+ if self.split_extension:
+ output_field, _ = os.path.splitext(input_field)
+ else:
+ output_field = input_field
+ # Remove the input field
+ del input_data[input_field]
+ with BytesIO(video_bytes) as f:
+ # We can directly close the file again as VideoReader makes an internal copy.
+ vr = decord.VideoReader(f, **self.video_reader_kwargs)
+ vrs[output_field] = vr
+
+ for derived_key, videos_and_additional_info in chunking(vrs, key, input_data):
+ yield {
+ **input_data,
+ "__key__": derived_key,
+ **videos_and_additional_info,
+ }
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ return lambda pipeline: pipeline.then(
+ functools.partial(self.video_decoding, chunking=self._chunk_iterator)
+ )
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ return lambda pipeline: pipeline.then(
+ functools.partial(self.video_decoding, chunking=self._chunk_iterator)
+ )
+
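+
+# Sketch for `VideoDecoder` (the field name is illustrative): decode the "video.mp4" entry
+# of each webdataset sample into a single tensor, keeping every second frame. With
+# `split_extension=True` the decoded tensor is stored under "video", and the sample also
+# gains a "decoded_indices" entry listing the decoded frame indices.
+#
+#     VideoDecoder(input_fields="video.mp4", stride=2, split_extension=True)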
+
+class DecodeRandomWindow(VideoDecoder):
+ """Decode a random window of the video."""
+
+ def __init__(self, n_consecutive_frames: int, **video_decoder_args):
+ self.n_consecutive_frames = n_consecutive_frames
+ self._random = None
+ super().__init__(**video_decoder_args)
+
+ @property
+ def random(self):
+ if not self._random:
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info:
+ self._random = random.Random(worker_info.seed)
+ else:
+ self._random = random.Random(torch.initial_seed())
+
+ return self._random
+
+ def _chunk_iterator(
+ self, vrs: Mapping[str, decord.VideoReader], key: str, inputs: Dict[str, Any]
+ ) -> Tuple[str, torch.Tensor]:
+ """Iterate over chunks of the video.
+
+ Returns:
+ str: Derived key which combines chunk and video key.
+ torch.Tensor: Chunk of video data.
+            Dict: Additional information, for example which frames were selected. This might be of
+ relevance when different modalities need to be sliced in a similar fashion as the
+ video input.
+ """
+ n_frames = len(next(iter(vrs.values())))
+ assert self.n_consecutive_frames * self.stride < n_frames
+ starting_index = self.random.randint(0, n_frames - self.n_consecutive_frames * self.stride)
+ indices = list(
+ range(
+ starting_index, starting_index + self.n_consecutive_frames * self.stride, self.stride
+ )
+ )
+ videos = {output_field: vr.get_batch(indices) for output_field, vr in vrs.items()}
+ yield f"{key}_{starting_index}", {**videos, "decoded_indices": indices}
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ # Do not split during evaluation.
+ return lambda pipeline: pipeline.then(
+ functools.partial(
+ self.video_decoding, chunking=functools.partial(VideoDecoder._chunk_iterator, self)
+ )
+ )
+
+
+class DecodeRandomStridedWindow(DecodeRandomWindow):
+ """Decode random strided segment of input video."""
+
+ def _chunk_iterator(
+ self, vrs: Mapping[str, decord.VideoReader], key: str, inputs: Dict[str, Any]
+ ) -> Tuple[str, torch.Tensor]:
+ """Iterate over chunks of the video.
+
+        This subclass returns a single randomly selected strided segment of the video rather
+        than the whole video.
+
+ Returns:
+ str: Derived key which combines chunk and video key.
+ torch.Tensor: Chunk of video data.
+            Dict: Additional information, for example which frames were selected. This might be of
+ relevance when different modalities need to be sliced in a similar fashion as the
+ video input.
+ """
+ n_frames = len(next(iter(vrs.values())))
+ segment_indices = list(range(0, n_frames + 1, self.n_consecutive_frames * self.stride))
+ segment_index = self.random.randint(0, len(segment_indices) - 2)
+ indices = list(
+ range(segment_indices[segment_index], segment_indices[segment_index + 1], self.stride)
+ )
+ videos = {output_field: vr.get_batch(indices) for output_field, vr in vrs.items()}
+ yield f"{key}_{segment_index}", {**videos, "decoded_indices": indices}
+
+
+class RandomStridedWindow(Plugin):
+ """Select a random consecutive subsequence of frames in a strided manner.
+
+ Given a sequence of [1, 2, 3, 4, 5, 6, 7, 8, 9] this will return one of
+ [1, 2, 3] [4, 5, 6] [7, 8, 9].
+ """
+
+ def __init__(
+ self,
+ n_consecutive_frames: int,
+ training_fields: Sequence[str],
+ evaluation_fields: Sequence[str],
+ dim: int = 0,
+ ):
+ self.n_consecutive_frames = n_consecutive_frames
+ self._training_fields = tuple(training_fields)
+ self._evaluation_fields = tuple(evaluation_fields)
+ self.dim = dim
+ self._random = None
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ @property
+ def random(self):
+ if not self._random:
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info:
+ self._random = random.Random(worker_info.seed)
+ else:
+ self._random = random.Random(torch.initial_seed())
+
+ return self._random
+
+ def split_to_consecutive_frames(self, data, fields):
+        """Select one random chunk of consecutive frames from each video in the generator."""
+ for sample in data:
+ key = sample["__key__"]
+ n_frames = sample[fields[0]].shape[self.dim]
+
+ splitted_fields = [
+ np.array_split(
+ sample[field],
+ range(self.n_consecutive_frames, n_frames, self.n_consecutive_frames),
+ axis=self.dim,
+ )
+ for field in fields
+ ]
+
+ n_fragments = len(splitted_fields[0])
+
+            if splitted_fields[0][-1].shape[self.dim] < self.n_consecutive_frames:
+ # Discard last fragment if too short.
+ n_fragments -= 1
+
+ fragment_id = self.random.randint(0, n_fragments - 1)
+ sliced_fields = {
+ field_name: splitted_field[fragment_id]
+ for field_name, splitted_field in zip(fields, splitted_fields)
+ }
+ sliced_fields["__key__"] = f"{key}_{fragment_id}"
+ yield {**sample, **sliced_fields}
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._training_fields) > 0:
+ return pipeline.then(
+ functools.partial(self.split_to_consecutive_frames, fields=self._training_fields)
+ )
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._evaluation_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.split_to_consecutive_frames,
+ fields=self._evaluation_fields,
+ )
+ )
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
+
+
+def rename_according_to_mapping(input: dict, mapping: dict):
+ output = {key: value for key, value in input.items() if key not in mapping.keys()}
+ for source, target in mapping.items():
+ output[target] = input[source]
+ return output
+
+
+class RenameFields(Plugin):
+ def __init__(
+ self,
+ train_mapping: Optional[Dict[str, str]] = None,
+ evaluation_mapping: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__()
+ self.train_mapping = train_mapping if train_mapping else {}
+ self.evaluation_mapping = evaluation_mapping if evaluation_mapping else {}
+
+ @hooks.hook_implementation
+ def training_fields(self) -> Tuple[str]:
+ return tuple(self.train_mapping.keys())
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def rename_fields(pipeline: webdataset.Processor):
+ if len(self.training_fields()):
+ return pipeline.map(
+ functools.partial(rename_according_to_mapping, mapping=self.train_mapping)
+ )
+ else:
+ return pipeline
+
+ return rename_fields
+
+ # Do same thing during training and testing.
+ @hooks.hook_implementation
+ def evaluation_fields(self) -> Tuple[str]:
+ return tuple(self.evaluation_mapping.keys())
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def rename_fields(pipeline: webdataset.Processor):
+ if len(self.evaluation_fields()):
+ return pipeline.map(
+ functools.partial(rename_according_to_mapping, mapping=self.evaluation_mapping)
+ )
+ else:
+ return pipeline
+
+ return rename_fields
+
+
+class DeterministicSubsampleWithMasking(Plugin):
+ def __init__(
+ self,
+ samples_per_instance: int,
+ training_fields: Optional[List[str]] = None,
+ evaluation_fields: Optional[List[str]] = None,
+ mask_field: Optional[str] = None,
+ seed: int = 42,
+ ):
+ super().__init__()
+ self.samples_per_instance = samples_per_instance
+ self._training_fields = training_fields if training_fields else []
+ self._evaluation_fields = evaluation_fields if evaluation_fields else []
+ self.mask_field = mask_field
+ self.seed = seed
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ def subsample_with_masking(self, instance, fields):
+ key = instance["__key__"]
+ random_state = random.Random(int(key) + self.seed)
+ n_frames = instance[fields[0]].shape[0]
+ indices = np.array(random_state.sample(range(n_frames), self.samples_per_instance))
+
+ output = instance.copy()
+ for field in fields:
+ values_to_keep = instance[field][indices]
+ field_output = np.full_like(instance[field], np.NaN)
+ field_output[indices] = values_to_keep
+ output[field] = field_output
+
+ if self.mask_field:
+ mask = np.zeros(n_frames, dtype=bool)
+ mask[indices] = True
+ output[self.mask_field] = mask
+
+ return output
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def subsample_with_masking(pipeline: webdataset.Processor):
+ if len(self._training_fields) > 0:
+ return pipeline.map(
+ functools.partial(self.subsample_with_masking, fields=self._training_fields)
+ )
+ else:
+ return pipeline
+
+ return subsample_with_masking
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def subsample_with_masking(pipeline: webdataset.Processor):
+ if len(self._evaluation_fields) > 0:
+ return pipeline.map(
+ functools.partial(
+ self.subsample_with_masking,
+ fields=self._evaluation_fields,
+ )
+ )
+ else:
+ return pipeline
+
+ return subsample_with_masking
+
+
+class SpatialSlidingWindow(Plugin):
+ """Split image data spatially by sliding a window across."""
+
+ def __init__(
+ self,
+ window_size: Tuple[int, int],
+ stride: Tuple[int, int],
+ padding: Tuple[int, int, int, int],
+ training_fields: Sequence[str],
+ evaluation_fields: Sequence[str],
+ expected_n_windows: Optional[int] = None,
+ ):
+ self.window_size = window_size
+ self.stride = stride
+ self.padding = padding
+ self.expected_n_windows = expected_n_windows
+ self._training_fields = tuple(training_fields) if training_fields else tuple()
+ self._evaluation_fields = tuple(evaluation_fields) if evaluation_fields else tuple()
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ @staticmethod
+ def pad(elem, padding):
+ if elem.shape[-1] != 1 and elem.shape[-1] != 3:
+ elem = elem[..., None]
+ orig_height = elem.shape[-3]
+ orig_width = elem.shape[-2]
+
+ p_left, p_top, p_right, p_bottom = padding
+ height = orig_height + p_top + p_bottom
+ width = orig_width + p_left + p_right
+
+ padded_shape = list(elem.shape[:-3]) + [height, width, elem.shape[-1]]
+ elem_padded = np.zeros_like(elem, shape=padded_shape)
+ elem_padded[..., p_top : p_top + orig_height, p_left : p_left + orig_width, :] = elem
+
+ return elem_padded
+
+ def sliding_window(self, data, fields):
+ for sample in data:
+ key = sample["__key__"]
+
+ window_x, window_y = self.window_size
+ stride_x, stride_y = self.stride
+ padded_elems = {key: self.pad(sample[key], self.padding) for key in fields}
+
+ n_windows = 0
+ x = 0
+ y = 0
+ while True:
+ shape = None
+ windowed_fields = {}
+                for field in fields:
+                    elem_padded = padded_elems[field]
+                    if shape is None:
+                        shape = elem_padded.shape
+                    elif shape[-3:-1] != elem_padded.shape[-3:-1]:
+                        raise ValueError("Element height, width after padding do not match")
+                    windowed_fields[field] = elem_padded[..., y : y + window_y, x : x + window_x, :]
+
+                    window_height, window_width = windowed_fields[field].shape[-3:-1]
+ assert (
+ window_y == window_height and window_x == window_width
+ ), f"Expected {window_y}, {window_x}, received {window_height}, {window_width}"
+
+ windowed_fields["__key__"] = f"{key}_{x - self.padding[0]}_{y - self.padding[1]}"
+ yield {**sample, **windowed_fields}
+ n_windows += 1
+
+ x += stride_x
+ if x >= shape[-2]:
+ y += stride_y
+ x = 0
+ if y >= shape[-3]:
+ break
+
+ if self.expected_n_windows is not None and self.expected_n_windows != n_windows:
+ raise ValueError(f"Expected {self.expected_n_windows} windows, but got {n_windows}")
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def apply_sliding_window(pipeline: webdataset.Processor):
+ if len(self._training_fields) > 0:
+ return pipeline.then(
+ functools.partial(self.sliding_window, fields=self._training_fields)
+ )
+ else:
+ return pipeline
+
+ return apply_sliding_window
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def apply_sliding_window(pipeline: webdataset.Processor):
+ if len(self._evaluation_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.sliding_window,
+ fields=self._evaluation_fields,
+ )
+ )
+ else:
+ return pipeline
+
+ return apply_sliding_window
+
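+# Hypothetical usage sketch for SpatialSlidingWindow (the field name "image" and the sizes
+# are assumptions): a 128x128 sample with a 64x64 window and stride yields four windows,
+# each emitted as its own sample with an offset-augmented "__key__".
+def _spatial_sliding_window_sketch():
+    plugin = SpatialSlidingWindow(
+        window_size=(64, 64),
+        stride=(64, 64),
+        padding=(0, 0, 0, 0),
+        training_fields=["image"],
+        evaluation_fields=["image"],
+        expected_n_windows=4,
+    )
+    sample = {"__key__": "0", "image": np.zeros((128, 128, 3), dtype=np.uint8)}
+    windows = list(plugin.sliding_window(iter([sample]), fields=("image",)))
+    assert len(windows) == 4 and windows[0]["image"].shape == (64, 64, 3)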
+
+class MaskInstances(Plugin):
+    """Filter instances by masking the fields of non-matching instances with NaN."""
+
+ def __init__(
+ self,
+ training_fields: Optional[List[str]] = None,
+ training_keys_to_keep: Optional[List[str]] = None,
+ evaluation_fields: Optional[List[str]] = None,
+ evaluation_keys_to_keep: Optional[List[str]] = None,
+ mask_video: bool = False,
+ ):
+ self._training_fields = training_fields
+ self.training_keys_to_keep = set(training_keys_to_keep) if training_keys_to_keep else None
+ self._evaluation_fields = evaluation_fields
+ self.evaluation_keys_to_keep = (
+ set(evaluation_keys_to_keep) if evaluation_keys_to_keep else None
+ )
+ self.mask_video = mask_video
+ if self.mask_video:
+ if self.training_keys_to_keep is not None:
+ self.train_video_key_to_frame_mapping = defaultdict(set)
+ for key in self.training_keys_to_keep:
+ video_key, frame = key.split("_")
+ self.train_video_key_to_frame_mapping[video_key].add(int(frame))
+ if self.evaluation_keys_to_keep is not None:
+                self.eval_video_key_to_frame_mapping = defaultdict(set)
+ for key in self.evaluation_keys_to_keep:
+ video_key, frame = key.split("_")
+ self.eval_video_key_to_frame_mapping[video_key].add(int(frame))
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ def mask_instance(self, instance, fields, keys):
+ key = instance["__key__"]
+
+ if key not in keys:
+ for field in fields:
+ data = instance[field]
+ if isinstance(data, np.ndarray):
+ instance[field] = np.full_like(data, np.NaN)
+ elif isinstance(data, torch.Tensor):
+ instance[field] = torch.full_like(data, np.NaN)
+ else:
+ raise RuntimeError(f"Field {field} is of unexpected type {type(data)}.")
+ return instance
+
+ def mask_instance_video(self, instance, fields, video_key_to_frame_mapping):
+ key = instance["__key__"]
+ output = instance.copy()
+ for field in fields:
+ data = instance[field]
+ if isinstance(data, np.ndarray):
+ output[field] = np.full_like(data, np.NaN)
+ elif isinstance(data, torch.Tensor):
+ output[field] = torch.full_like(data, np.NaN)
+ else:
+ raise RuntimeError(f"Field {field} is of unexpected type {type(data)}.")
+
+ # We need to do some special handling here due to the strided decoding.
+ # This is not really nice, but fixing it nicely would require significantly
+ # more work for which we do not have the time at the moment.
+ if "decoded_indices" in instance.keys():
+ # Input comes from strided decoding, we thus need to adapt
+ # key and frames.
+ key, _ = key.split("_") # Get video key.
+ key = str(int(key))
+ if key in video_key_to_frame_mapping.keys():
+ frames_to_keep = video_key_to_frame_mapping[key]
+ decoded_indices = instance["decoded_indices"]
+ frames_to_keep = [index for index in decoded_indices if index in frames_to_keep]
+ for field in fields:
+ data = instance[field]
+ output[field][frames_to_keep] = data[frames_to_keep]
+ else:
+ if key in video_key_to_frame_mapping.keys():
+ frames_to_keep = video_key_to_frame_mapping[key]
+ for field in fields:
+ data = instance[field]
+ output[field][frames_to_keep] = data[frames_to_keep]
+ return output
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def subsample_with_masking(pipeline: webdataset.Processor):
+ if self._training_fields:
+ if self.mask_video:
+ return pipeline.map(
+ functools.partial(
+ self.mask_instance_video,
+ fields=self._training_fields,
+ video_key_to_frame_mapping=self.train_video_key_to_frame_mapping,
+ )
+ )
+ else:
+ return pipeline.map(
+ functools.partial(
+ self.mask_instance,
+ fields=self._training_fields,
+ keys=self.training_keys_to_keep,
+ )
+ )
+ else:
+ return pipeline
+
+ return subsample_with_masking
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def subsample_with_masking(pipeline: webdataset.Processor):
+ if self._evaluation_fields:
+ if self.mask_video:
+ return pipeline.map(
+ functools.partial(
+ self.mask_instance_video,
+ fields=self._evaluation_fields,
+ video_key_to_frame_mapping=self.eval_video_key_to_frame_mapping,
+ )
+ )
+ else:
+ return pipeline.map(
+ functools.partial(
+ self.mask_instance,
+ fields=self._evaluation_fields,
+ keys=self.evaluation_keys_to_keep,
+ )
+ )
+ else:
+ return pipeline
+
+ return subsample_with_masking
+
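+# Hypothetical usage sketch for MaskInstances (keys and field name are assumptions):
+# instances whose "__key__" is listed are passed through unchanged, while all other
+# instances get the listed fields overwritten with NaN.
+def _mask_instances_sketch():
+    plugin = MaskInstances(training_fields=["image"], training_keys_to_keep=["0003"])
+    keep = {"__key__": "0003", "image": np.ones((2, 2), dtype=np.float32)}
+    drop = {"__key__": "0007", "image": np.ones((2, 2), dtype=np.float32)}
+    kept = plugin.mask_instance(keep, fields=["image"], keys=plugin.training_keys_to_keep)
+    dropped = plugin.mask_instance(drop, fields=["image"], keys=plugin.training_keys_to_keep)
+    assert kept["image"][0, 0] == 1.0 and np.isnan(dropped["image"]).all()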
+
+class FlattenVideoToImage(Plugin):
+ def __init__(
+ self,
+ training_fields: Sequence[str],
+ evaluation_fields: Sequence[str],
+ shuffle_buffer_size: int = 0,
+ ):
+ """Flatten input video tensors into images.
+
+ Args:
+ training_fields: The fields that should be considered video data and thus sliced
+ according to the frame sampling during training.
+ evaluation_fields: The fields that should be considered video data and thus sliced
+ according to the frame sampling during evaluation.
+ shuffle_buffer_size: Size of shuffle buffer used during training. An additional
+ shuffling step ensures each batch contains a diverse set of images and not only
+ images from the same video.
+ """
+ self._training_fields = tuple(training_fields)
+ self._evaluation_fields = tuple(evaluation_fields)
+ self.shuffle_buffer_size = shuffle_buffer_size
+
+ def flatten_video(self, data, fields):
+        """Flatten a generator of videos into a generator of frames using the __key__ field."""
+        for sample in data:
+            key = sample["__key__"]
+ # TODO (hornmax): We assume all fields to have the same size. I do not want to check
+ # this here as it seems a bit verbose.
+ n_frames = sample[fields[0]].shape[0]
+
+ for frame in range(n_frames):
+ # Slice the fields according to the frame.
+ sliced_fields = {field: sample[field][frame] for field in fields}
+ # Leave all fields besides the sliced ones as before, augment the __key__ field to
+ # include the frame number.
+ sliced_fields["__key__"] = f"{key}_{frame}"
+ yield {**sample, **sliced_fields}
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def flatten_video(pipeline: webdataset.Processor):
+ if len(self._training_fields) > 0:
+ return pipeline.then(
+ functools.partial(self.flatten_video, fields=self._training_fields)
+ ).shuffle(self.shuffle_buffer_size)
+ else:
+ return pipeline
+
+ return flatten_video
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def flatten_video(pipeline: webdataset.Processor):
+ if len(self._evaluation_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.flatten_video,
+ fields=self._evaluation_fields,
+ )
+ )
+ else:
+ return pipeline
+
+ return flatten_video
+
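+# Hypothetical usage sketch for FlattenVideoToImage (shapes and key are assumptions): a
+# video with T frames becomes T image samples whose "__key__" is suffixed with the frame
+# index, while all non-video fields are carried over unchanged.
+def _flatten_video_sketch():
+    plugin = FlattenVideoToImage(training_fields=["video"], evaluation_fields=["video"])
+    sample = {"__key__": "0001", "video": np.zeros((3, 8, 8, 3), dtype=np.uint8)}
+    frames = list(plugin.flatten_video(iter([sample]), fields=("video",)))
+    assert len(frames) == 3
+    assert frames[0]["__key__"] == "0001_0" and frames[0]["video"].shape == (8, 8, 3)
+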
+def transform_lift_dict(
+    elements: dict,
+    *,
+    transform_dict: Dict[str, Callable],
+    element_key: str,
+    element_key_to_lift: List[str],
+    to_del: bool,
+):
+    """Transform entries of a nested dict element and lift them to the top level.
+
+    Defined at module level (instead of as a closure) to avoid issues with pickling.
+    """
+    for key in element_key_to_lift:
+        elements[key] = transform_dict[key](elements[element_key][key])
+    if to_del:
+        elements.pop(element_key, None)
+    return elements
+
+class SingleElementPreprocessingLiftDict(Plugin):
+    """Preprocessing of elements stored in a nested dictionary of the input data.
+
+    This is useful to build preprocessing pipelines based on existing element transformations such
+    as those provided by torchvision. The entries listed in `element_key_to_lift` are read from
+    the dictionary stored under `element_key`, transformed, and lifted to the top level of the
+    sample; the nested dictionary can optionally be removed afterwards via `to_del`.
+    """
+
+    def __init__(
+        self,
+        training_transform: Dict[str, Callable],
+        evaluation_transform: Dict[str, Callable],
+        element_key: str = "image",
+        element_key_to_lift: Optional[List[str]] = None,
+        to_del: bool = False,
+    ):
+        self._training_transform = training_transform
+        self._evaluation_transform = evaluation_transform
+        self.element_key = element_key
+        self.element_key_to_lift = element_key_to_lift
+        self.to_del = to_del
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return (self.element_key, )
+
+ @hooks.hook_implementation
+    def training_transform(self):
+        if self._training_transform:
+
+            def transform(pipeline: webdataset.Processor):
+                transform_func = functools.partial(
+                    transform_lift_dict,
+                    transform_dict=self._training_transform,
+                    element_key=self.element_key,
+                    element_key_to_lift=self.element_key_to_lift,
+                    to_del=self.to_del,
+                )
+                return pipeline.map(transform_func)
+
+            return transform
+        else:
+            return None
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return (self.element_key, )
+
+ @hooks.hook_implementation
+    def evaluation_transform(self):
+        if self._evaluation_transform:
+
+            def transform(pipeline: webdataset.Processor):
+                transform_func = functools.partial(
+                    transform_lift_dict,
+                    transform_dict=self._evaluation_transform,
+                    element_key=self.element_key,
+                    element_key_to_lift=self.element_key_to_lift,
+                    to_del=self.to_del,
+                )
+                return pipeline.map(transform_func)
+
+            return transform
+        else:
+            return None
+
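+# Hypothetical usage sketch for transform_lift_dict / SingleElementPreprocessingLiftDict
+# (the element and field names are assumptions): entries of the nested dictionary under
+# `element_key` are transformed and lifted to the top level of the sample; with
+# to_del=True the nested dictionary is dropped afterwards.
+def _transform_lift_dict_sketch():
+    sample = {"inputs": {"image": np.zeros((2, 2), dtype=np.float32)}}
+    out = transform_lift_dict(
+        sample,
+        transform_dict={"image": lambda x: x + 1.0},
+        element_key="inputs",
+        element_key_to_lift=["image"],
+        to_del=True,
+    )
+    assert out["image"][0, 0] == 1.0 and "inputs" not in out
+
+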
+class SequenceSampleFramesFromVideo(Plugin):
+ def __init__(
+ self,
+ n_frames_per_video: int,
+ training_fields: Sequence[str],
+ evaluation_fields: Sequence[str],
+ dim: int = 0,
+ seed: int = 39480234,
+ per_epoch: bool = False,
+ shuffle_buffer_size: int = 1000,
+ n_eval_frames_per_video: Optional[int] = None,
+ ):
+ """Sample frames from input tensors.
+
+ Args:
+ n_frames_per_video: Number of frames per video to sample. -1 indicates that all frames
+ should be sampled.
+ training_fields: The fields that should be considered video data and thus sliced
+ according to the frame sampling during training.
+ evaluation_fields: The fields that should be considered video data and thus sliced
+ according to the frame sampling during evaluation.
+ dim: The dimension along which to slice the tensors.
+            seed: Random number generator seed for deterministic sampling during evaluation.
+            per_epoch: Stratify the sampling of frames over epochs, such that after
+ n_frames / n_frames_per_video epochs all frames have been seen at least once.
+ In the case of uneven division, some frames will be seen more than once.
+ shuffle_buffer_size: Size of shuffle buffer used during training. An additional
+ shuffling step ensures each batch contains a diverse set of images and not only
+ images from the same video.
+ n_eval_frames_per_video: Number of frames per video to sample on the evaluation splits.
+ """
+ self.n_frames_per_video = n_frames_per_video
+ self._training_fields = tuple(training_fields)
+ self._evaluation_fields = tuple(evaluation_fields)
+ self.dim = dim
+ self.seed = seed
+ self.per_epoch = per_epoch
+ self.shuffle_buffer_size = shuffle_buffer_size
+ if n_eval_frames_per_video is not None:
+ self.n_eval_frames_per_video = n_eval_frames_per_video
+ else:
+ self.n_eval_frames_per_video = n_frames_per_video
+
+ def slice_data(self, data, index: int):
+ """Small utility method to slice a numpy array along a specified axis."""
+ n_dims_before = self.dim
+ n_dims_after = data.ndim - 1 - self.dim
+ slices = (slice(None),) * n_dims_before + (index,) + (slice(None),) * n_dims_after
+ return data[slices]
+
+ def sample_frames_using_key(self, data, fields, seed, n_frames_per_video):
+ """Sample frames deterministically from generator of videos using the __key__ field."""
+        for sample in data:
+            key = sample["__key__"]
+            n_frames = sample[fields[0]].shape[self.dim]
+            frames_per_video = n_frames_per_video if n_frames_per_video != -1 else n_frames
+
+ if self.per_epoch and self.n_frames_per_video != -1:
+ n_different_epochs_per_seed = int(math.ceil(n_frames / frames_per_video))
+ try:
+ epoch = int(os.environ["WDS_EPOCH"])
+ except KeyError:
+ raise RuntimeError(
+                        "Using SequenceSampleFramesFromVideo with per_epoch=True "
+ "requires `WDS_EPOCH` to be set."
+ )
+ # Only update the seed after n_frames / n_frames_per_video epochs.
+ # This ensures that we get the same random order of frames until
+ # we have sampled all of them.
+ rand = np.random.RandomState(
+ int(key) + seed + (epoch // n_different_epochs_per_seed)
+ )
+ indices = rand.permutation(n_frames)
+ selected_frames = indices[
+ epoch * self.n_frames_per_video : (epoch + 1) * self.n_frames_per_video
+ ].tolist()
+ if len(selected_frames) < self.n_frames_per_video:
+ # Input cannot be evenly split, take some frames from the first batch of frames.
+ n_missing = self.n_frames_per_video - len(selected_frames)
+ selected_frames.extend(indices[0:n_missing].tolist())
+ else:
+ rand = random.Random(int(key) + seed)
+ selected_frames = rand.sample(range(n_frames), k=frames_per_video)
+
+ for frame in selected_frames:
+ # Slice the fields according to the frame, we use copy in order to allow freeing of
+ # the original tensor.
+ sliced_fields = {
+ field: self.slice_data(sample[field], frame).copy() for field in fields
+ }
+ # Leave all fields besides the sliced ones as before, augment the __key__ field to
+ # include the frame number.
+ sliced_fields["__key__"] = f"{key}_{frame}"
+ yield {**sample, **sliced_fields}
+
+ # Delete fields to be sure we remove all references.
+ for field in fields:
+ del sample[field]
+
+ @hooks.hook_implementation
+ def training_fields(self):
+ return self._training_fields
+
+ @hooks.hook_implementation
+ def training_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._training_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.sample_frames_using_key,
+ fields=self._training_fields,
+ seed=self.seed,
+ n_frames_per_video=self.n_frames_per_video,
+ )
+ ).shuffle(self.shuffle_buffer_size)
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
+
+ @hooks.hook_implementation
+ def evaluation_fields(self):
+ return self._evaluation_fields
+
+ @hooks.hook_implementation
+ def evaluation_transform(self):
+ def apply_deterministic_sampling(pipeline: webdataset.Processor):
+ if len(self._evaluation_fields) > 0:
+ return pipeline.then(
+ functools.partial(
+ self.sample_frames_using_key,
+ fields=self._evaluation_fields,
+ seed=self.seed + 1,
+ n_frames_per_video=self.n_eval_frames_per_video,
+ )
+ )
+ else:
+ return pipeline
+
+ return apply_deterministic_sampling
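+
+
+# Hypothetical usage sketch for SequenceSampleFramesFromVideo (shapes and key are
+# assumptions): with per_epoch=False, `sample_frames_using_key` draws a deterministic,
+# key-dependent subset of frames from each video.
+def _sequence_sample_frames_sketch():
+    plugin = SequenceSampleFramesFromVideo(
+        n_frames_per_video=2, training_fields=["video"], evaluation_fields=["video"]
+    )
+    sample = {"__key__": "000015", "video": np.zeros((6, 4, 4, 3), dtype=np.uint8)}
+    frames = list(
+        plugin.sample_frames_using_key(
+            iter([sample]), fields=("video",), seed=plugin.seed, n_frames_per_video=2
+        )
+    )
+    assert len(frames) == 2 and frames[0]["video"].shape == (4, 4, 3)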
diff --git a/ocl/predictor.py b/ocl/predictor.py
new file mode 100644
index 0000000..9f89db2
--- /dev/null
+++ b/ocl/predictor.py
@@ -0,0 +1,85 @@
+import torch
+import torchvision
+from torch import nn
+
+
+class Predictor(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int = 128,
+ num_heads: int = 4,
+ qkv_size: int = 128,
+ mlp_size: int = 256,
+ pre_norm: bool = False,
+ ):
+ nn.Module.__init__(self)
+
+ self.embed_dim = embed_dim
+ self.qkv_size = qkv_size
+ self.mlp_size = mlp_size
+ self.num_heads = num_heads
+ self.pre_norm = pre_norm
+ self.MHA = nn.MultiheadAttention(embed_dim, num_heads)
+
+ self.head_dim = qkv_size // num_heads
+ self.mlp = torchvision.ops.MLP(embed_dim, [mlp_size, embed_dim])
+ # layernorms
+ self.layernorm_query = nn.LayerNorm(embed_dim, eps=1e-6)
+ self.layernorm_mlp = nn.LayerNorm(embed_dim, eps=1e-6)
+ # weights
+ self.dense_q = nn.Linear(embed_dim, qkv_size)
+ self.dense_k = nn.Linear(embed_dim, qkv_size)
+ self.dense_v = nn.Linear(embed_dim, qkv_size)
+ if self.num_heads > 1:
+ self.dense_o = nn.Linear(qkv_size, embed_dim)
+ self.multi_head = True
+ else:
+ self.multi_head = False
+
+ def forward(
+ self, object_features: torch.Tensor
+ ): # TODO: add general attention for q, k, v, not just for x = qkv
+ assert object_features.ndim == 3
+ B, L, _ = object_features.shape
+ head_dim = self.embed_dim // self.num_heads
+
+ if self.pre_norm:
+ # Self-attention.
+ x = self.layernorm_query(object_features)
+ q = self.dense_q(x).view(B, L, self.num_heads, head_dim)
+ k = self.dense_k(x).view(B, L, self.num_heads, head_dim)
+ v = self.dense_v(x).view(B, L, self.num_heads, head_dim)
+ x, _ = self.MHA(q, k, v)
+ if self.multi_head:
+ x = self.dense_o(x.reshape(B, L, self.qkv_size)).view(B, L, self.embed_dim)
+ else:
+ x = x.squeeze(-2)
+ x = x + object_features
+
+ y = x
+
+ # MLP
+ z = self.layernorm_mlp(y)
+ z = self.mlp(z)
+ z = z + y
+ else:
+ # Self-attention on queries.
+ x = object_features
+ q = self.dense_q(x).view(B, L, self.num_heads, head_dim)
+ k = self.dense_k(x).view(B, L, self.num_heads, head_dim)
+ v = self.dense_v(x).view(B, L, self.num_heads, head_dim)
+ x, _ = self.MHA(q, k, v)
+ if self.multi_head:
+ x = self.dense_o(x.reshape(B, L, self.qkv_size)).view(B, L, self.embed_dim)
+ else:
+ x = x.squeeze(-2)
+ x = x + object_features
+ x = self.layernorm_query(x)
+
+ y = x
+
+ # MLP
+ z = self.mlp(y)
+ z = z + y
+ z = self.layernorm_mlp(z)
+ return z
diff --git a/ocl/preprocessing.py b/ocl/preprocessing.py
new file mode 100644
index 0000000..687c502
--- /dev/null
+++ b/ocl/preprocessing.py
@@ -0,0 +1,1221 @@
+"""Data preprocessing functions."""
+import random
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy
+import torch
+from torchvision import transforms
+from torchvision.ops import masks_to_boxes
+
+from ocl.utils.bboxes import box_xyxy_to_cxcywh
+
+
+class DropEntries:
+ """Drop entries from data dictionary."""
+
+ def __init__(self, keys: List[str]):
+ self.keys = tuple(keys)
+
+ def __call__(self, data: Dict[str, Any]):
+ return {k: v for k, v in data.items() if k not in self.keys}
+
+
+class CheckFormat:
+ """Check format of data."""
+
+ def __init__(self, shape: List[int], one_hot: bool = False, class_dim: int = 0):
+ self.shape = tuple(shape)
+ self.one_hot = one_hot
+ self.class_dim = class_dim
+
+ def __call__(self, data: torch.Tensor) -> torch.Tensor:
+ if data.shape != self.shape:
+ raise ValueError(f"Expected shape to be {self.shape}, but is {data.shape}")
+
+ if self.one_hot:
+ if not torch.all(data.sum(self.class_dim) == 1):
+ raise ValueError("Data is not one-hot")
+
+ return data
+
+
+class CompressMask:
+ def __call__(self, mask: numpy.ndarray) -> numpy.ndarray:
+ non_empty = numpy.any(mask != 0, axis=(0, 2, 3))
+        # Preserve the first object being empty. This is often considered the
+ # foreground mask and sometimes ignored.
+ last_nonempty_index = len(non_empty) - non_empty[::-1].argmax()
+ input_arr = mask[:, :last_nonempty_index]
+ n_objects = input_arr.shape[1]
+ dtype = numpy.uint8
+ if n_objects > 8:
+ dtype = numpy.uint16
+ if n_objects > 16:
+ dtype = numpy.uint32
+ if n_objects > 32:
+ dtype = numpy.uint64
+ if n_objects > 64:
+ raise RuntimeError("We do not support more than 64 objects at the moment.")
+
+ object_flag = (1 << numpy.arange(n_objects, dtype=dtype))[None, :, None, None]
+ output_arr = numpy.sum(input_arr.astype(dtype) * object_flag, axis=1).astype(dtype)
+ return output_arr
+
+
+class CompressedMaskToTensor:
+ def __call__(self, compressed_mask: numpy.ndarray) -> torch.Tensor:
+ maximum_value = numpy.max(compressed_mask)
+ n_objects = 0
+ while maximum_value > 0:
+ maximum_value //= 2
+ n_objects += 1
+
+ if n_objects == 0:
+ # Cover edge case of no objects.
+ n_objects = 1
+
+ squeeze = False
+ if len(compressed_mask.shape) == 2:
+ compressed_mask = compressed_mask[None, ...]
+ squeeze = True
+ # Not really sure why we need to invert the order here, but it seems
+ # to be necessary for the index to remain consistent between compression
+ # and decompression.
+ is_bit_active = (1 << numpy.arange(n_objects, dtype=compressed_mask.dtype))[
+ None, :, None, None
+ ]
+ expanded_mask = (compressed_mask[:, None, :, :] & is_bit_active) > 0
+ if squeeze:
+ expanded_mask = numpy.squeeze(expanded_mask, axis=0)
+ return torch.from_numpy(expanded_mask).to(torch.float32)
+
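+# Minimal roundtrip sketch (illustration only, shapes are assumptions): CompressMask packs
+# up to 64 binary instance masks into one integer map with one bit per object, and
+# CompressedMaskToTensor unpacks the bits into a float tensor again.
+def _compress_mask_roundtrip_sketch():
+    mask = numpy.zeros((1, 3, 4, 4), dtype=numpy.uint8)  # (T, K, H, W) binary masks
+    mask[0, 0, :2, :2] = 1  # object 0 occupies the top-left corner
+    mask[0, 2, 2:, 2:] = 1  # object 2 occupies the bottom-right corner
+    compressed = CompressMask()(mask)  # (T, H, W), bit i encodes object i
+    decompressed = CompressedMaskToTensor()(compressed)  # (T, K, H, W) float tensor
+    assert compressed.shape == (1, 4, 4) and decompressed.shape == (1, 3, 4, 4)
+    assert bool(decompressed[0, 2, 3, 3] == 1.0) and bool(decompressed[0, 1].sum() == 0)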
+
+class MaskToTensor:
+ """Convert a segmentation mask numpy array to a tensor.
+
+ Mask is assumed to be of shape (..., K, H, W, 1), i.e. one-hot encoded with K classes and any
+ number of leading dimensions. Returned tensor is of shape (..., K, H, W), containing binary
+ entries.
+ """
+
+ def __init__(self, singleton_dim_last: bool = True):
+ self.singleton_dim_last = singleton_dim_last
+
+ def __call__(self, mask: numpy.ndarray) -> torch.Tensor:
+ mask_binary = mask > 0.0
+ if self.singleton_dim_last:
+ assert mask_binary.shape[-1] == 1
+ return torch.from_numpy(mask_binary).squeeze(-1).to(torch.float32)
+ else:
+ return torch.from_numpy(mask_binary).to(torch.float32)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}()"
+
+
+class DenseMaskToTensor:
+ """Convert a dense segmentation mask numpy array to a tensor.
+
+ Mask is assumed to be of shape (..., K, H, W, 1), i.e. densely encoded with K classes and any
+ number of leading dimensions. Returned tensor is of shape (..., K, H, W).
+ """
+
+ def __call__(self, mask: numpy.ndarray) -> torch.Tensor:
+ assert mask.shape[-1] == 1
+ return torch.from_numpy(mask).squeeze(-1).to(torch.uint8)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}()"
+
+
+class MultiMaskToTensor:
+    """Discretize a mask in which multiple objects may be partially present at a pixel.
+
+    The result is an exclusive (one-hot) binary mask with a single object per pixel.
+    """
+
+ def __init__(self, axis: int = -4):
+ self.axis = axis
+
+ def __call__(self, mask: numpy.ndarray) -> torch.Tensor:
+ int_mask = numpy.argmax(mask, axis=self.axis).squeeze(-1)
+ out_mask = torch.nn.functional.one_hot(torch.from_numpy(int_mask), mask.shape[self.axis])
+ # Ensure the object axis is again at the same location.
+ # We operate on the shape prior to squeezing for axis to be consistent.
+ last_index = len(out_mask.shape) - 1
+ indices = list(range(len(out_mask.shape) + 1))
+ indices.insert(self.axis, last_index)
+ indices = indices[:-2] # Remove last indices as they are squeezed or inserted.
+ out_mask = out_mask.permute(*indices).to(torch.float32)
+ return out_mask
+
+
+class IntegerToOneHotMask:
+ """Convert an integer mask to a one-hot mask.
+
+ Integer masks are masks where the instance ID is written into the mask.
+ This transform expands them to a one-hot encoding.
+
+ Args:
+ ignore_typical_background: Ignore pixels where the mask is zero or 255.
+ This often corresponds to the background or to the segmentation boundary.
+ """
+
+ def __init__(self, ignore_typical_background=True, output_axis=-4, max_instances=None):
+ self.ignore_typical_background = ignore_typical_background
+ self.output_axis = output_axis
+ self.max_instances = max_instances
+
+    def __call__(self, array: numpy.ndarray):
+ max_value = array.max()
+ if self.ignore_typical_background:
+ if max_value == 255:
+ # Replace 255 with zero, both are ignored.
+ array[array == 255] = 0
+ max_value = array.max()
+ max_instances = self.max_instances if self.max_instances else max_value
+ to_one_hot = numpy.concatenate(
+ [
+ numpy.zeros((1, max_instances), dtype=numpy.uint8),
+ numpy.eye(max_instances, dtype=numpy.uint8),
+ ],
+ axis=0,
+ )
+ else:
+ max_instances = self.max_instances if self.max_instances else max_value
+ to_one_hot = numpy.eye(max_instances + 1, dtype=numpy.uint8)
+ return numpy.moveaxis(to_one_hot[array], -1, self.output_axis)
+
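+# Minimal illustrative sketch (shapes are assumptions): an H x W x 1 integer instance mask
+# with ids {0, 1, 2} becomes a 2 x H x W x 1 one-hot mask, with id 0 treated as background
+# and dropped.
+def _integer_to_one_hot_sketch():
+    mask = numpy.zeros((4, 4, 1), dtype=numpy.uint8)
+    mask[:2, :2, 0] = 1
+    mask[2:, 2:, 0] = 2
+    one_hot = IntegerToOneHotMask()(mask)  # instance axis is inserted at output_axis=-4
+    assert one_hot.shape == (2, 4, 4, 1)
+    assert one_hot[0, 0, 0, 0] == 1 and one_hot[1, 3, 3, 0] == 1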
+
+class VOCInstanceMasksToDenseMasks:
+ """Convert a segmentation mask with integer encoding into a one-hot segmentation mask.
+
+    We use this transform as the Pascal VOC segmentation and object annotations seem to not
+ be aligned.
+ """
+
+ def __init__(
+ self,
+ instance_mask_key: str = "segmentation-instance",
+ class_mask_key: str = "segmentation-class",
+ classes_key: str = "instance_category",
+ ignore_mask_key: str = "ignore_mask",
+ instance_axis: int = -4,
+ ):
+ self.instance_mask_key = instance_mask_key
+ self.class_mask_key = class_mask_key
+ self.classes_key = classes_key
+ self.ignore_mask_key = ignore_mask_key
+ self.instance_axis = instance_axis
+
+ def __call__(self, data: Dict[str, Any]):
+ data[self.ignore_mask_key] = (data[self.class_mask_key] == 255)[None] # 1 x H x W x 1
+ expanded_segmentation_mask = data[self.instance_mask_key] * numpy.expand_dims(
+ data[self.class_mask_key], axis=self.instance_axis
+ )
+ assert expanded_segmentation_mask.max() != 255
+ data[self.instance_mask_key] = expanded_segmentation_mask
+ classes = []
+ for instance_slice in numpy.rollaxis(expanded_segmentation_mask, self.instance_axis):
+ unique_values = numpy.unique(instance_slice)
+ assert len(unique_values) == 2 # Should contain 0 and class id.
+ classes.append(unique_values[1])
+ data[self.classes_key] = numpy.array(classes)
+
+ return data
+
+
+class AddImageSize:
+ """Add height and width of image as data entry.
+
+ Args:
+ key: Key of image.
+ target_key: Key under which to store size.
+ """
+
+ def __init__(self, key: str = "image", target_key: str = "image_size"):
+ self.key = key
+ self.target_key = target_key
+
+ def __call__(self, data: Dict[str, Any]):
+ height, width, _ = data[self.key].shape
+ data[self.target_key] = numpy.array([height, width], dtype=numpy.int64)
+ return data
+
+
+class AddEmptyMasks:
+ """Add empty masks to data if the data does not include them already.
+
+ Args:
+        mask_keys: One or several keys of empty masks to be added.
+ take_size_from: Key of element whose height and width is used to create mask. Element is
+ assumed to have shape of (H, W, C).
+ """
+
+ def __init__(self, mask_keys: Union[str, Sequence[str]], take_size_from: str = "image"):
+ if isinstance(mask_keys, str):
+ self.mask_keys = (mask_keys,)
+ else:
+ self.mask_keys = tuple(mask_keys)
+ self.source_key = take_size_from
+
+ def __call__(self, data: Dict[str, Any]):
+ height, width, _ = data[self.source_key].shape
+ for key in self.mask_keys:
+ if key not in data:
+ data[key] = numpy.zeros((1, height, width, 1), dtype=numpy.uint8)
+
+ return data
+
+
+class AddEmptyBboxes:
+ """Add empty bounding boxes to data if the data does not include them already.
+
+ Args:
+ keys: One or several keys of empty boxes to be added.
+ empty_value: Value of the empty box at all coordinates.
+ """
+
+ def __init__(self, keys: Union[str, Sequence[str]] = "instance_bbox", empty_value: float = -1.0):
+ if isinstance(keys, str):
+ self.keys = (keys,)
+ else:
+ self.keys = tuple(keys)
+ self.empty_value = empty_value
+
+ def __call__(self, data: Dict[str, Any]):
+ for key in self.keys:
+ if key not in data:
+ data[key] = numpy.ones((1, 4), dtype=numpy.float32) * self.empty_value
+
+ return data
+
+
+class CanonicalizeBboxes:
+ """Convert bounding boxes to canonical (x1, y1, x2, y2) format.
+
+ Args:
+ key: Key of bounding box, assumed to have shape K x 4.
+ format: Format of bounding boxes. Either "xywh" or "yxyx".
+ """
+
+ def __init__(self, key: str = "instance_bbox", format: str = "xywh"):
+ self.key = key
+
+ self.format_xywh = False
+ self.format_yxyx = False
+ if format == "xywh":
+ self.format_xywh = True
+ elif format == "yxyx":
+ self.format_yxyx = True
+ else:
+ raise ValueError(f"Unknown input format `{format}`")
+
+ def __call__(self, data: Dict[str, Any]):
+ if self.key not in data:
+ return data
+
+ bboxes = data[self.key]
+ if self.format_xywh:
+ x1, y1, w, h = numpy.split(bboxes, 4, axis=1)
+ x2 = x1 + w
+ y2 = y1 + h
+ elif self.format_yxyx:
+ y1, x1, y2, x2 = numpy.split(bboxes, 4, axis=1)
+
+ data[self.key] = numpy.concatenate((x1, y1, x2, y2), axis=1)
+
+ return data
+
+
+class RescaleBboxes:
+ """Rescale bounding boxes by size taken from data.
+
+ Bounding boxes are assumed to have format (x1, y1, x2, y2). The rescaled box is
+ (x1 * width, y1 * height, x2 * width, y2 * height).
+
+ Args:
+ key: Key of bounding box, assumed to have shape K x 4.
+ take_size_from: Key of element to take the size for rescaling from, assumed to have shape
+ H x W x C.
+ """
+
+ def __init__(self, key: str = "instance_bbox", take_size_from: str = "image"):
+ self.key = key
+ self.take_size_from = take_size_from
+
+ def __call__(self, data: Dict[str, Any]):
+ if self.key not in data:
+ return data
+
+ height, width, _ = data[self.take_size_from].shape
+ scaling = numpy.array([[width, height, width, height]], dtype=numpy.float32)
+ data[self.key] = data[self.key] * scaling
+
+ return data
+
+
+def expand_dense_mask(mask: numpy.ndarray) -> numpy.ndarray:
+ """Convert dense segmentation mask to one where each class occupies one dimension.
+
+ Args:
+ mask: Densely encoded segmentation mask of shape 1 x H x W x 1.
+
+ Returns: Densely encoded segmentation mask of shape K x H x W x 1, where K is the
+ number of classes in the mask. Zero is taken to indicate an unoccupied pixel.
+ """
+ classes = numpy.unique(mask)[:, None, None, None]
+ mask = (classes == mask) * classes
+
+ # Strip empty class, but only if there is something else in the mask
+ if classes[0].squeeze() == 0 and len(classes) != 1:
+ mask = mask[1:]
+
+ return mask
+
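+# Minimal worked example for expand_dense_mask (illustration only): a dense 1 x H x W x 1
+# mask containing class ids 3 and 7 is expanded into a 2 x H x W x 1 mask with one slice
+# per class, still densely encoded.
+def _expand_dense_mask_sketch():
+    dense = numpy.zeros((1, 4, 4, 1), dtype=numpy.uint8)
+    dense[0, :2, :2, 0] = 3  # class 3 occupies the top-left corner
+    dense[0, 2:, 2:, 0] = 7  # class 7 occupies the bottom-right corner
+    expanded = expand_dense_mask(dense)
+    assert expanded.shape == (2, 4, 4, 1)
+    assert expanded[0].max() == 3 and expanded[1].max() == 7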
+
+class AddSegmentationMaskFromInstanceMask:
+ """Convert instance to segmentation masks by joining instances with the same category.
+
+ Overlaps of instances of different classes are resolved by taking the class with the higher class
+ id.
+ """
+
+ def __init__(
+ self,
+ instance_mask_key: str = "instance_mask",
+ target_key: str = "segmentation_mask",
+ ):
+ self.instance_mask_key = instance_mask_key
+ self.target_key = target_key
+
+ @staticmethod
+ def convert(instance_mask: numpy.ndarray) -> numpy.ndarray:
+ """Convert instance to segmentation mask.
+
+ Args:
+ instance_mask: Densely encoded instance masks of shape I x H x W x 1, where I is the
+ number of instances.
+ """
+ # Reduce instance mask to single dimension
+ instance_mask = instance_mask.max(axis=0, keepdims=True)
+
+ return expand_dense_mask(instance_mask)
+
+ def __call__(self, data: Dict[str, Any]):
+ if self.instance_mask_key not in data:
+ return data
+
+ data[self.target_key] = self.convert(data[self.instance_mask_key])
+
+ return data
+
+
+class RenameFields:
+ def __init__(self, mapping: Dict):
+ self.mapping = mapping
+
+ def __call__(self, d: Dict):
+        # Create a shallow copy to avoid issues if a target key is already in use.
+ out = d.copy()
+ for source, target in self.mapping.items():
+ out[target] = d[source]
+ return out
+
+
+class AddBBoxFromInstanceMasks:
+ """Convert instance mask to bounding box.
+
+ Args:
+ instance_mask_key: mask key name.
+ target_key: target key name.
+ """
+
+ def __init__(
+ self,
+ instance_mask_key: str = "mask",
+ video_id_key: str = "__key__", # not quite sure if this is the best key
+ target_box_key: str = "instance_bbox",
+ target_cls_key: str = "instance_cls",
+ target_id_key: str = "instance_id",
+ ):
+ self.instance_mask_key = instance_mask_key
+ self.video_id_key = video_id_key
+ self.target_box_key = target_box_key
+ self.target_cls_key = target_cls_key
+ self.target_id_key = target_id_key
+
+ @staticmethod
+ def convert(instance_mask: numpy.ndarray, video_id: numpy.ndarray) -> numpy.ndarray:
+ num_frame, num_instance, height, width, _ = instance_mask.shape
+
+ # Convert to binary mask
+ binary_mask = instance_mask > 0
+ # Filter background. TODO: now we assume the first mask for each video is background.
+ # Might not apply to every dataset
+ binary_mask = binary_mask[:, 1:]
+ num_instance -= 1
+ binary_mask = (
+ torch.tensor(binary_mask).squeeze().view(num_frame * num_instance, height, width)
+ )
+ # Filter empty masks
+ non_empty_mask_idx = torch.where(binary_mask.sum(-1).sum(-1) > 0)[0]
+ empty_mask_idx = torch.where(binary_mask.sum(-1).sum(-1) == 0)[0]
+ non_empty_binary_mask = binary_mask[non_empty_mask_idx]
+ non_empty_bboxes = masks_to_boxes(non_empty_binary_mask)
+
+ # Turn box into cxcyhw
+ bboxes = torch.zeros(num_frame * num_instance, 4)
+ non_empty_bboxes = box_xyxy_to_cxcywh(non_empty_bboxes)
+ bboxes[non_empty_mask_idx] = non_empty_bboxes
+ # normalized to 0,1
+ # Make sure width and height are correct
+ bboxes[:, 0::2] = bboxes[:, 0::2] / width
+ bboxes[:, 1::2] = bboxes[:, 1::2] / height
+ bboxes = bboxes.view(num_frame, num_instance, 4).squeeze(-1).to(torch.float32)
+
+ # class
+ # -1 is background or no object, 0 is the first object class
+ instance_cls = torch.ones(num_frame * num_instance, 1) * -1
+ instance_cls[non_empty_mask_idx] = 0
+ instance_cls = instance_cls.view(num_frame, num_instance, 1).squeeze(-1).to(torch.long)
+
+ # ID
+        instance_id = torch.arange(num_instance)[None, :, None].repeat(num_frame, 1, 1)
+ instance_id = instance_id.view(num_frame * num_instance, 1)
+ instance_id[empty_mask_idx] = -1
+ instance_id = instance_id.view(num_frame, num_instance, 1).squeeze(-1).to(torch.long)
+
+ return bboxes, instance_cls, instance_id
+
+ def __call__(self, data: Dict[str, Any]):
+ if self.instance_mask_key not in data:
+ return data
+
+ bboxes, instance_cls, instance_id = self.convert(
+ data[self.instance_mask_key], data[self.video_id_key]
+ )
+ data[self.target_box_key] = bboxes
+ data[self.target_cls_key] = instance_cls
+ data[self.target_id_key] = instance_id
+ return data
+
+
+class InstanceMasksToDenseMasks:
+ """Convert binary instance masks to dense masks, i.e. where the mask value encodes the class id.
+
+ Class ids are taken from a list containing a class id per instance.
+ """
+
+ def __init__(
+ self,
+ instance_mask_key: str = "instance_mask",
+ category_key: str = "instance_category",
+ ):
+ self.instance_mask_key = instance_mask_key
+ self.category_key = category_key
+
+ @staticmethod
+ def convert(instance_mask: numpy.ndarray, categories: numpy.ndarray) -> numpy.ndarray:
+ if numpy.min(categories) <= 0:
+            raise ValueError("Detected category smaller than or equal to 0 in instance masks.")
+ if numpy.max(categories) > 255:
+ raise ValueError(
+ "Detected category greater than 255 in instance masks. This does not fit in uint8."
+ )
+
+ categories = categories[:, None, None, None]
+ return (instance_mask * categories).astype(numpy.uint8)
+
+ def __call__(self, data: Dict[str, Any]):
+ if self.instance_mask_key not in data:
+ return data
+
+ data[self.instance_mask_key] = self.convert(
+ data[self.instance_mask_key], data[self.category_key]
+ )
+
+ return data
+
+
+class MergeCocoThingsAndStuff:
+ """Merge COCO things and stuff segmentation masks.
+
+ Args:
+ things_key: Key to things instance mask. Mask is assumed to be densely encoded, i.e.
+ the mask value encodes the class id, of shape I x H x W x 1, where I is the number of
+ things instances.
+ stuff_key: Key to stuff segmentation mask. Mask is assumed to be densely encoded, i.e.
+ the mask value encodes the class id, of shape K x H x W x 1, where K is the number stuff
+ classes.
+ output_key: Key under which the merged mask is stored. Returns mask of shape L x H x W x 1,
+ where K <= L <= K + I.
+ include_crowd: Whether to include pixels marked as crowd with their class, or with class
+ zero.
+ """
+
+ def __init__(
+ self,
+ output_key: str,
+ things_key: str = "instance_mask",
+ stuff_key: str = "stuff_mask",
+ include_crowd: bool = False,
+ ):
+ self.things_key = things_key
+ self.stuff_key = stuff_key
+ self.output_key = output_key
+ self.include_crowd = include_crowd
+
+ def __call__(self, data: Dict[str, Any]):
+ if self.things_key in data:
+ things_instance_mask = data[self.things_key]
+ things_mask = things_instance_mask.max(axis=0, keepdims=True)
+ else:
+ things_mask = None
+
+ stuff_mask = data[self.stuff_key]
+ merged_mask = stuff_mask.max(axis=0, keepdims=True)
+
+ # In stuff annotations, thing pixels are encoded as class 183.
+ use_thing_mask = merged_mask == 183
+
+ if things_mask is not None:
+ if self.include_crowd:
+ # In the stuff annotations, things marked with the "crowd" label are NOT encoded as
+ # class 183, but as class 0. We can take the value of the things mask for those
+ # pixels.
+ use_thing_mask |= merged_mask == 0
+ merged_mask[use_thing_mask] = things_mask[use_thing_mask]
+ else:
+ # No pixel should have value 183 if the things_mask does not exist, but convert it to
+ # zero anyways just to be sure.
+ merged_mask[use_thing_mask] = 0
+
+ data[self.output_key] = expand_dense_mask(merged_mask)
+
+ return data
+
+
+class FlowToTensor:
+ """Convert an optical flow numpy array to a tensor.
+
+ Flow is assumed to be of shape (..., H, W, 2), returned tensor is of shape (..., 2, H, W) to
+ match with VideoTensor format.
+ """
+
+ def __call__(self, flow: numpy.ndarray) -> torch.Tensor:
+ flow_torch = torch.from_numpy(flow.astype(float)).to(torch.float32)
+ return torch.moveaxis(flow_torch, -1, -3)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}()"
+
+
+class ConvertCocoStuff164kMasks:
+ """Convert COCO-Stuff-164k PNG segmentation masks to our format.
+
+ Args:
+ output_key: Key under which the output mask is stored. Returns uint8 mask of shape
+ K x H x W x 1, where K is the number of classes in the image. Mask is densely encoded,
+ i.e. the mask values encode the class id.
+ stuffthings_key: Key to COCO-Stuff-164k PNG mask. Mask has shape H x W x 3.
+ ignore_key: Key under which the ignore mask is stored. Returns bool mask of shape
+ 1 x H x W x 1. Ignores pixels where PNG mask has value 255 (crowd).
+ drop_stuff: If true, remove all stuff classes (id >= 92), keeping only thing classes.
+ """
+
+ def __init__(
+ self,
+ output_key: str,
+ stuffthings_key: str = "stuffthings_mask",
+ ignore_key: str = "ignore_mask",
+ drop_stuff: bool = False,
+ ):
+ self.stuffthings_key = stuffthings_key
+ self.ignore_key = ignore_key
+ self.output_key = output_key
+ self.drop_stuff = drop_stuff
+
+ def __call__(self, data: Dict[str, Any]):
+ mask = data[self.stuffthings_key] # H x W x 3, mask is encoded as an image
+ assert mask.shape[-1] == 3
+ mask = mask[:, :, :1] # Take first channel, all channels are the same
+
+ ignore_mask = mask == 255
+
+ # In PNG annotations, classes occupy indices 0-181, shift by 1
+ mask = mask + 1
+ mask[ignore_mask] = 0
+
+ if self.drop_stuff:
+ mask[mask >= 92] = 0
+
+ data[self.ignore_key] = ignore_mask[None] # 1 x H x W x 1
+ data[self.output_key] = expand_dense_mask(mask[None]) # K x H x W x 1
+
+ return data
+
+
+class VideoToTensor:
+ """Convert a video numpy array of shape (T, H, W, C) to a torch tensor of shape (T, C, H, W)."""
+
+ def __call__(self, video):
+ """Convert a numpy array of a video into a torch tensor.
+
+ Assumes input is a numpy array of shape T x H x W x C (or T x H x W for monochrome videos)
+ and convert it into torch tensor of shape T x C x H x W in order to allow application of
+ Conv3D operations.
+ """
+ if isinstance(video, numpy.ndarray):
+ # Monochrome video such as mask
+ if video.ndim == 3:
+ video = video[..., None]
+
+ video = torch.from_numpy(video.transpose((0, 3, 1, 2))).contiguous()
+ # backward compatibility
+ if isinstance(video, torch.ByteTensor):
+ return video.to(dtype=torch.get_default_dtype()).div(255)
+ else:
+ return video
+ else:
+ # Should be torch tensor.
+ if video.ndim == 3:
+ video = video[..., None]
+
+ video = video.permute(0, 3, 1, 2).contiguous()
+ # backward compatibility
+ if isinstance(video, torch.ByteTensor):
+ return video.to(dtype=torch.get_default_dtype()).div(255)
+ else:
+ return video
+
+
+class ToSingleFrameVideo:
+ """Convert image in tensor format to video format by adding frame dimension with single element.
+
+ Converts C x H x W tensors into tensors of shape 1 x C x H x W.
+ """
+
+ def __call__(self, image):
+ return image.unsqueeze(0)
+
+
+class NormalizeVideo:
+ """Normalize a video tensor of shape (T, C, H, W)."""
+
+ def __init__(self, mean, std):
+ self.mean = torch.tensor(mean)[None, :, None, None]
+ self.std = torch.tensor(std)[None, :, None, None]
+
+ def __call__(self, video):
+ return (video - self.mean) / self.std
+
+
+class Denormalize(torch.nn.Module):
+ """Denormalize a tensor of shape (..., C, H, W) with any number of leading dimensions."""
+
+ def __init__(self, mean, std):
+ super().__init__()
+ self.register_buffer("mean", torch.tensor(mean)[:, None, None])
+ self.register_buffer("std", torch.tensor(std)[:, None, None])
+
+ def __call__(self, tensor):
+ return tensor * self.std + self.mean
+
+
+class ResizeNearestExact:
+ """Resize a tensor using mode nearest-exact.
+
+ This mode is not available in torchvision.transforms.Resize as of v0.12. This class was adapted
+ from torchvision.transforms.functional_tensor.resize.
+ """
+
+ def __init__(self, size: Union[int, List[int]], max_size: Optional[int] = None):
+ self.size = size
+ self.max_size = max_size
+
+ @staticmethod
+ def _cast_squeeze_in(
+ img: torch.Tensor, req_dtypes: List[torch.dtype]
+ ) -> Tuple[torch.Tensor, bool, bool, torch.dtype]:
+ need_squeeze = False
+ # make image NCHW
+ if img.ndim < 4:
+ img = img.unsqueeze(dim=0)
+ need_squeeze = True
+
+ out_dtype = img.dtype
+ need_cast = False
+ if out_dtype not in req_dtypes:
+ need_cast = True
+ req_dtype = req_dtypes[0]
+ img = img.to(req_dtype)
+ return img, need_cast, need_squeeze, out_dtype
+
+ @staticmethod
+ def _cast_squeeze_out(
+ img: torch.Tensor, need_cast: bool, need_squeeze: bool, out_dtype: torch.dtype
+ ) -> torch.Tensor:
+ if need_squeeze:
+ img = img.squeeze(dim=0)
+
+ if need_cast:
+ if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+ # it is better to round before cast
+ img = torch.round(img)
+ img = img.to(out_dtype)
+
+ return img
+
+ @staticmethod
+ def resize(img: torch.Tensor, size: Union[int, List[int]], max_size: Optional[int] = None):
+ h, w = img.shape[-2:]
+
+ if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
+ short, long = (w, h) if w <= h else (h, w)
+ requested_new_short = size if isinstance(size, int) else size[0]
+
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+ if max_size is not None:
+ if max_size <= requested_new_short:
+ raise ValueError(
+ f"max_size = {max_size} must be strictly greater than the requested "
+ f"size for the smaller edge size = {size}"
+ )
+ if new_long > max_size:
+ new_short, new_long = int(max_size * new_short / new_long), max_size
+
+ new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+
+ if (w, h) == (new_w, new_h):
+ return img
+ else: # specified both h and w
+ new_w, new_h = size[1], size[0]
+
+ img, need_cast, need_squeeze, out_dtype = ResizeNearestExact._cast_squeeze_in(
+ img, (torch.float32, torch.float64)
+ )
+
+ img = torch.nn.functional.interpolate(img, size=[new_h, new_w], mode="nearest-exact")
+
+ img = ResizeNearestExact._cast_squeeze_out(
+ img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype
+ )
+
+ return img
+
+ def __call__(self, img: torch.Tensor) -> torch.Tensor:
+ return ResizeNearestExact.resize(img, self.size, self.max_size)
+
+
+class ConvertToCocoSuperclasses:
+ """Convert segmentation mask from COCO classes (183) to COCO superclasses (27)."""
+
+ ID_TO_SUPERCLASS_AND_NAME = {
+ 0: ("unlabeled", "unlabeled"),
+ 1: ("person", "person"),
+ 2: ("vehicle", "bicycle"),
+ 3: ("vehicle", "car"),
+ 4: ("vehicle", "motorcycle"),
+ 5: ("vehicle", "airplane"),
+ 6: ("vehicle", "bus"),
+ 7: ("vehicle", "train"),
+ 8: ("vehicle", "truck"),
+ 9: ("vehicle", "boat"),
+ 10: ("outdoor", "traffic light"),
+ 11: ("outdoor", "fire hydrant"),
+ 13: ("outdoor", "stop sign"),
+ 14: ("outdoor", "parking meter"),
+ 15: ("outdoor", "bench"),
+ 16: ("animal", "bird"),
+ 17: ("animal", "cat"),
+ 18: ("animal", "dog"),
+ 19: ("animal", "horse"),
+ 20: ("animal", "sheep"),
+ 21: ("animal", "cow"),
+ 22: ("animal", "elephant"),
+ 23: ("animal", "bear"),
+ 24: ("animal", "zebra"),
+ 25: ("animal", "giraffe"),
+ 27: ("accessory", "backpack"),
+ 28: ("accessory", "umbrella"),
+ 31: ("accessory", "handbag"),
+ 32: ("accessory", "tie"),
+ 33: ("accessory", "suitcase"),
+ 34: ("sports", "frisbee"),
+ 35: ("sports", "skis"),
+ 36: ("sports", "snowboard"),
+ 37: ("sports", "sports ball"),
+ 38: ("sports", "kite"),
+ 39: ("sports", "baseball bat"),
+ 40: ("sports", "baseball glove"),
+ 41: ("sports", "skateboard"),
+ 42: ("sports", "surfboard"),
+ 43: ("sports", "tennis racket"),
+ 44: ("kitchen", "bottle"),
+ 46: ("kitchen", "wine glass"),
+ 47: ("kitchen", "cup"),
+ 48: ("kitchen", "fork"),
+ 49: ("kitchen", "knife"),
+ 50: ("kitchen", "spoon"),
+ 51: ("kitchen", "bowl"),
+ 52: ("food", "banana"),
+ 53: ("food", "apple"),
+ 54: ("food", "sandwich"),
+ 55: ("food", "orange"),
+ 56: ("food", "broccoli"),
+ 57: ("food", "carrot"),
+ 58: ("food", "hot dog"),
+ 59: ("food", "pizza"),
+ 60: ("food", "donut"),
+ 61: ("food", "cake"),
+ 62: ("furniture", "chair"),
+ 63: ("furniture", "couch"),
+ 64: ("furniture", "potted plant"),
+ 65: ("furniture", "bed"),
+ 67: ("furniture", "dining table"),
+ 70: ("furniture", "toilet"),
+ 72: ("electronic", "tv"),
+ 73: ("electronic", "laptop"),
+ 74: ("electronic", "mouse"),
+ 75: ("electronic", "remote"),
+ 76: ("electronic", "keyboard"),
+ 77: ("electronic", "cell phone"),
+ 78: ("appliance", "microwave"),
+ 79: ("appliance", "oven"),
+ 80: ("appliance", "toaster"),
+ 81: ("appliance", "sink"),
+ 82: ("appliance", "refrigerator"),
+ 84: ("indoor", "book"),
+ 85: ("indoor", "clock"),
+ 86: ("indoor", "vase"),
+ 87: ("indoor", "scissors"),
+ 88: ("indoor", "teddy bear"),
+ 89: ("indoor", "hair drier"),
+ 90: ("indoor", "toothbrush"),
+ 92: ("textile", "banner"),
+ 93: ("textile", "blanket"),
+ 94: ("plant", "branch"),
+ 95: ("building", "bridge"),
+ 96: ("building", "building-other"),
+ 97: ("plant", "bush"),
+ 98: ("furniture-stuff", "cabinet"),
+ 99: ("structural", "cage"),
+ 100: ("raw-material", "cardboard"),
+ 101: ("floor", "carpet"),
+ 102: ("ceiling", "ceiling-other"),
+ 103: ("ceiling", "ceiling-tile"),
+ 104: ("textile", "cloth"),
+ 105: ("textile", "clothes"),
+ 106: ("sky", "clouds"),
+ 107: ("furniture-stuff", "counter"),
+ 108: ("furniture-stuff", "cupboard"),
+ 109: ("textile", "curtain"),
+ 110: ("furniture-stuff", "desk-stuff"),
+ 111: ("ground", "dirt"),
+ 112: ("furniture-stuff", "door-stuff"),
+ 113: ("structural", "fence"),
+ 114: ("floor", "floor-marble"),
+ 115: ("floor", "floor-other"),
+ 116: ("floor", "floor-stone"),
+ 117: ("floor", "floor-tile"),
+ 118: ("floor", "floor-wood"),
+ 119: ("plant", "flower"),
+ 120: ("water", "fog"),
+ 121: ("food-stuff", "food-other"),
+ 122: ("food-stuff", "fruit"),
+ 123: ("furniture-stuff", "furniture-other"),
+ 124: ("plant", "grass"),
+ 125: ("ground", "gravel"),
+ 126: ("ground", "ground-other"),
+ 127: ("solid", "hill"),
+ 128: ("building", "house"),
+ 129: ("plant", "leaves"),
+ 130: ("furniture-stuff", "light"),
+ 131: ("textile", "mat"),
+ 132: ("raw-material", "metal"),
+ 133: ("furniture-stuff", "mirror-stuff"),
+ 134: ("plant", "moss"),
+ 135: ("solid", "mountain"),
+ 136: ("ground", "mud"),
+ 137: ("textile", "napkin"),
+ 138: ("structural", "net"),
+ 139: ("raw-material", "paper"),
+ 140: ("ground", "pavement"),
+ 141: ("textile", "pillow"),
+ 142: ("plant", "plant-other"),
+ 143: ("raw-material", "plastic"),
+ 144: ("ground", "platform"),
+ 145: ("ground", "playingfield"),
+ 146: ("structural", "railing"),
+ 147: ("ground", "railroad"),
+ 148: ("water", "river"),
+ 149: ("ground", "road"),
+ 150: ("solid", "rock"),
+ 151: ("building", "roof"),
+ 152: ("textile", "rug"),
+ 153: ("food-stuff", "salad"),
+ 154: ("ground", "sand"),
+ 155: ("water", "sea"),
+ 156: ("furniture-stuff", "shelf"),
+ 157: ("sky", "sky-other"),
+ 158: ("building", "skyscraper"),
+ 159: ("ground", "snow"),
+ 160: ("solid", "solid-other"),
+ 161: ("furniture-stuff", "stairs"),
+ 162: ("solid", "stone"),
+ 163: ("plant", "straw"),
+ 164: ("structural", "structural-other"),
+ 165: ("furniture-stuff", "table"),
+ 166: ("building", "tent"),
+ 167: ("textile", "textile-other"),
+ 168: ("textile", "towel"),
+ 169: ("plant", "tree"),
+ 170: ("food-stuff", "vegetable"),
+ 171: ("wall", "wall-brick"),
+ 172: ("wall", "wall-concrete"),
+ 173: ("wall", "wall-other"),
+ 174: ("wall", "wall-panel"),
+ 175: ("wall", "wall-stone"),
+ 176: ("wall", "wall-tile"),
+ 177: ("wall", "wall-wood"),
+ 178: ("water", "water-other"),
+ 179: ("water", "waterdrops"),
+ 180: ("window", "window-blind"),
+ 181: ("window", "window-other"),
+ 182: ("solid", "wood"),
+ 183: ("other", "other"),
+ }
+
+ SUPERCLASS_TO_ID = {
+ "unlabeled": 0,
+ "person": 1,
+ "vehicle": 2,
+ "outdoor": 3,
+ "animal": 4,
+ "accessory": 5,
+ "sports": 6,
+ "kitchen": 7,
+ "food": 8,
+ "furniture": 9,
+ "electronic": 10,
+ "appliance": 11,
+ "indoor": 12,
+ "textile": 13,
+ "plant": 14,
+ "building": 15,
+ "furniture-stuff": 16,
+ "structural": 17,
+ "raw-material": 18,
+ "floor": 19,
+ "ceiling": 20,
+ "sky": 21,
+ "ground": 22,
+ "water": 23,
+ "food-stuff": 24,
+ "wall": 25,
+ "window": 26,
+ "solid": 27,
+ "other": 28,
+ }
+
+ def __init__(self):
+ max_class = max(ConvertToCocoSuperclasses.ID_TO_SUPERCLASS_AND_NAME.keys())
+ class_to_superclass = numpy.zeros((max_class + 1,), dtype=numpy.uint8)
+ for class_id, (supclass, _) in ConvertToCocoSuperclasses.ID_TO_SUPERCLASS_AND_NAME.items():
+ class_to_superclass[class_id] = ConvertToCocoSuperclasses.SUPERCLASS_TO_ID[supclass]
+ self.class_to_superclass = class_to_superclass
+
+ def __call__(self, mask: numpy.ndarray) -> numpy.ndarray:
+ """Convert mask to superclasses.
+
+ Args:
+ mask: Densely encoded segmentation mask of shape K x H x W x 1.
+
+ Returns: Segmentation mask of shape C x H x W x 1, where C is the new set of classes.
+ """
+ classes = mask.reshape(len(mask), -1).max(axis=-1)
+ superclasses = self.class_to_superclass[classes]
+ mask = (mask > 0) * superclasses[:, None, None, None]
+
+ return expand_dense_mask(mask.max(axis=0, keepdims=True))
+
+
+class OrigCenterCrop:
+ """Returns center crop at original image resolution."""
+
+ def __call__(self, image):
+ height, width = image.shape[-2:]
+ return transforms.functional.center_crop(image, min(height, width))
+
+
+class JointRandomResizedCropwithParameters(transforms.RandomResizedCrop):
+ def __init__(
+ self,
+ size,
+ scale=(0.08, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ interpolation=transforms.functional.InterpolationMode.BILINEAR,
+ ):
+ super().__init__(size, scale, ratio, interpolation)
+ self.mask_to_tensor = DenseMaskToTensor()
+ self.mask_resize = ResizeNearestExact((size, size))
+
+ def forward(self, img, masks=None):
+        """Randomly crop and resize the image and masks jointly, returning the crop parameters.
+
+        Args:
+            img (PIL Image or Tensor): Image to be cropped and resized.
+            masks: Dictionary mapping keys to masks that are cropped and resized with the same
+                parameters as the image.
+
+        Returns:
+            Tuple of the randomly cropped and resized image, the cropped and resized masks,
+            and the crop parameters.
+        """
+ params = self.get_params(img, self.scale, self.ratio)
+ img = transforms.functional.resized_crop(img, *params, self.size, self.interpolation)
+
+ for mask_key, mask in masks.items():
+ if not isinstance(mask, torch.Tensor):
+ mask = self.mask_to_tensor(mask)
+ mask = transforms.functional.crop(mask, *params)
+ mask = self.mask_resize(mask)
+ masks[mask_key] = mask
+ return img, masks, params
+
+
+class MultiCrop(object):
+ def __init__(
+ self,
+ size: int = 224,
+ input_key: str = "image",
+ teacher_key: str = "teacher",
+ student_key: str = "student",
+ global_scale: Tuple[float] = (0.8, 1.0),
+ local_scale: Tuple[float] = (0.7, 1.0),
+ ratio: Tuple[float] = (3.0 / 4.0, 4.0 / 3.0),
+ mask_keys: Optional[Tuple[str]] = None,
+ ):
+ self.ratio = ratio
+ self.teacher_key = teacher_key
+ self.student_key = student_key
+ self.global_crop = JointRandomResizedCropwithParameters(size, global_scale, ratio)
+ self.local_crop = JointRandomResizedCropwithParameters(size, local_scale, ratio)
+ self.input_key = input_key
+ self.mask_keys = tuple(mask_keys) if mask_keys is not None else tuple()
+
+ def __call__(self, data):
+ if self.input_key not in data:
+ raise ValueError(f"Wrong input key {self.input_key}")
+ img = transforms.functional.to_tensor(data[self.input_key])
+ masks = {mask_key: data[mask_key] for mask_key in self.mask_keys}
+ teacher_view, global_masks, params = self.global_crop(img, masks)
+ data[self.teacher_key] = teacher_view
+ for k, mask in global_masks.items():
+ data[f"{self.teacher_key}_{k}"] = mask
+
+ student_view, local_masks, params = self.local_crop(teacher_view, global_masks)
+ data[self.student_key] = student_view
+ for k, mask in local_masks.items():
+ data[f"{self.student_key}_{k}"] = mask
+ data["params"] = torch.Tensor(numpy.array(params))
+ return data
+
+
+class TokenizeText:
+ def __init__(self, context_length: int = 77, truncate: bool = False):
+ self.context_length = context_length
+ self.truncate = truncate
+
+ def __call__(self, texts: Union[List[str], str]):
+ # TODO: Understand if there is any performance impact in importing here.
+ try:
+ import clip
+ except ImportError:
+ raise Exception("Using clip models requires installation with extra `clip`.")
+ tokenized = clip.tokenize(texts, self.context_length, self.truncate)
+ if isinstance(texts, str):
+ # If the input is a single string, then the tokenization adds an additional dimension
+ # which we don't need.
+ tokenized = tokenized[0]
+ return tokenized
+
+
+class RandomSample:
+ """Draw a random sample from the first axis of a list or array."""
+
+ def __call__(self, tokens):
+ return random.choice(tokens)
+
+
+class IsElementOfList:
+ def __init__(self, list: List[str]):
+ self.list = set(list)
+
+ def __call__(self, key):
+ return key in self.list
+
+
+class SampleFramesUsingIndices:
+    """Sample frames from a tensor dependent on indices provided in the instance."""
+
+ def __init__(self, frame_fields: List[str], index_field: str):
+ self.frame_fields = frame_fields
+ self.index_field = index_field
+
+ def __call__(self, inputs: dict):
+ indices = inputs[self.index_field]
+ for frame_field in self.frame_fields:
+ inputs[frame_field] = inputs[frame_field][indices]
+ return inputs
+
+import numpy as np
+
+class ArrayPadder:
+ def __init__(self, target_size, dim=0, pad_value=-1, switch=None):
+ """
+ Initialize the padder with target size, dimension and padding value.
+
+ :param target_size: The target size of the specified dimension after padding.
+ :param dim: The dimension along which to pad. Defaults to 0.
+ :param pad_value: The value to use for padding. Defaults to 0.
+ """
+ self.target_size = target_size
+ self.dim = dim
+ self.pad_value = pad_value
+ self.switch = switch
+
+
+    def __call__(self, arr):
+        """
+        Pad the given array using the specified target size, dimension, and padding value.
+
+        :param arr: Numpy array to be padded.
+        :return: Padded numpy array.
+        """
+        pad_size = self.target_size - arr.shape[self.dim]
+
+        if pad_size <= 0:
+            # Already at or beyond the target size; only apply the optional transpose.
+            if self.switch:
+                return np.transpose(arr, self.switch)
+            return arr
+
+        pad_width = [(0, 0) for _ in range(arr.ndim)]
+        pad_width[self.dim] = (0, pad_size)
+        results = np.pad(arr, pad_width, mode="constant", constant_values=self.pad_value)
+        if self.switch:
+            results = np.transpose(results, self.switch)
+
+        return results
+
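+    # Illustrative example (values assumed for demonstration): ArrayPadder(target_size=5, dim=0,
+    # pad_value=-1) maps an array of shape (3, 4) to shape (5, 4), filling the two new rows
+    # with -1.
+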
+    def __repr__(self):
+        """Return a readable summary of the padder configuration."""
+        return (
+            f"target_size: {self.target_size}, dim: {self.dim}, "
+            f"pad_value: {self.pad_value}, switch: {self.switch}"
+        )
\ No newline at end of file
diff --git a/ocl/scheduling.py b/ocl/scheduling.py
new file mode 100644
index 0000000..576165a
--- /dev/null
+++ b/ocl/scheduling.py
@@ -0,0 +1,173 @@
+"""Scheduling of learning rate and hyperparameters."""
+import abc
+import math
+import warnings
+from typing import Callable
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+def warmup_fn(step: int, warmup_steps: int) -> float:
+ """Learning rate warmup.
+
+ Maps the step to a factor for rescaling the learning rate.
+ """
+ if warmup_steps:
+ return min(1.0, step / warmup_steps)
+ else:
+ return 1.0
+
+
+def exp_decay_after_warmup_fn(
+ step: int, decay_rate: float, decay_steps: int, warmup_steps: int
+) -> float:
+    """Decay function for exponential decay that starts only after the learning rate warmup.
+
+    Maps the step to a factor for rescaling the learning rate.
+    """
+ factor = warmup_fn(step, warmup_steps)
+ if step < warmup_steps:
+ return factor
+ else:
+ return factor * (decay_rate ** ((step - warmup_steps) / decay_steps))
+
+
+def exp_decay_with_warmup_fn(
+ step: int, decay_rate: float, decay_steps: int, warmup_steps: int
+) -> float:
+ """Decay function for exponential decay with learning rate warmup.
+
+ Maps the step to a factor for rescaling the learning rate.
+ """
+ factor = warmup_fn(step, warmup_steps)
+ return factor * (decay_rate ** (step / decay_steps))
+
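+# Usage sketch (illustrative, assuming `import functools`): these factor functions can be
+# plugged into torch.optim.lr_scheduler.LambdaLR, e.g.
+#   lr_fn = functools.partial(
+#       exp_decay_with_warmup_fn, decay_rate=0.5, decay_steps=100000, warmup_steps=10000
+#   )
+#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)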
+
+class CosineAnnealingWithWarmup(_LRScheduler):
+ """Cosine annealing with warmup."""
+
+ def __init__(
+ self,
+ optimizer,
+ T_max: int,
+ warmup_steps: int = 0,
+ eta_min: float = 0.0,
+ last_epoch: int = -1,
+ error_on_exceeding_steps: bool = True,
+ verbose: bool = False,
+ ):
+ self.T_max = T_max
+ self.warmup_steps = warmup_steps
+ self.eta_min = eta_min
+ self.error_on_exceeding_steps = error_on_exceeding_steps
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def _linear_lr_warmup(self, base_lr, step_num):
+ return base_lr * ((step_num + 0.5) / self.warmup_steps)
+
+ def _cosine_annealing(self, base_lr, step_num):
+ fraction_of_steps = (step_num - self.warmup_steps) / (self.T_max - self.warmup_steps - 1)
+ return self.eta_min + 1 / 2 * (base_lr - self.eta_min) * (
+ 1 + math.cos(math.pi * fraction_of_steps)
+ )
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn(
+ "To get the last learning rate computed by the scheduler, "
+ "please use `get_last_lr()`."
+ )
+ step_num = self.last_epoch
+
+ if step_num < self.warmup_steps:
+ # Warmup.
+ return [self._linear_lr_warmup(base_lr, step_num) for base_lr in self.base_lrs]
+ elif step_num < self.T_max:
+ # Cosine annealing.
+ return [self._cosine_annealing(base_lr, step_num) for base_lr in self.base_lrs]
+ else:
+ if self.error_on_exceeding_steps:
+ raise ValueError(
+ "Tried to step {} times. The specified number of total steps is {}".format(
+ step_num + 1, self.T_max
+ )
+ )
+ else:
+ return [self.eta_min for _ in self.base_lrs]
+
+
+HPSchedulerT = Callable[[int], float] # Type for function signatures.
+
+
+class HPScheduler(metaclass=abc.ABCMeta):
+ """Base class for scheduling of scalar hyperparameters based on the number of training steps."""
+
+ @abc.abstractmethod
+ def __call__(self, step: int) -> float:
+ """Return current value of hyperparameter based on global step."""
+ pass
+
+
+class LinearHPScheduler(HPScheduler):
+ def __init__(
+ self, end_value: float, end_step: int, start_value: float = 0.0, start_step: int = 0
+ ):
+ super().__init__()
+ if start_step > end_step:
+            raise ValueError("`start_step` needs to be smaller than or equal to `end_step`.")
+
+ self.start_value = start_value
+ self.end_value = end_value
+ self.start_step = start_step
+ self.end_step = end_step
+
+ def __call__(self, step: int) -> float:
+ if step < self.start_step:
+ return self.start_value
+ elif step > self.end_step:
+ return self.end_value
+ else:
+ t = step - self.start_step
+ T = self.end_step - self.start_step
+ return self.start_value + t * (self.end_value - self.start_value) / T
+
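+# Illustrative values (assumed arguments): LinearHPScheduler(end_value=1.0, end_step=100)
+# returns 0.0 at step 0, 0.5 at step 50 and 1.0 from step 100 onwards.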
+
+class StepHPScheduler(HPScheduler):
+ def __init__(self, end_value: float, switch_step: int, start_value: float = 0.0):
+ super().__init__()
+ self.start_value = start_value
+ self.end_value = end_value
+ self.switch_step = switch_step
+
+    def __call__(self, step: int) -> float:
+        if step < self.switch_step:
+            return self.start_value
+        else:
+            return self.end_value
+
+
+class CosineAnnealingHPScheduler(HPScheduler):
+ """Cosine annealing."""
+
+ def __init__(self, start_value: float, end_value: float, start_step: int, end_step: int):
+ super().__init__()
+ assert start_value >= end_value
+ assert start_step <= end_step
+ self.start_value = start_value
+ self.end_value = end_value
+ self.start_step = start_step
+ self.end_step = end_step
+
+ def __call__(self, step: int) -> float:
+
+ if step < self.start_step:
+ value = self.start_value
+ elif step >= self.end_step:
+ value = self.end_value
+ else:
+ a = 0.5 * (self.start_value - self.end_value)
+ b = 0.5 * (self.start_value + self.end_value)
+ progress = (step - self.start_step) / (self.end_step - self.start_step)
+ value = a * math.cos(math.pi * progress) + b
+
+ return value
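+
+
+# Illustrative values (assumed arguments): CosineAnnealingHPScheduler(1.0, 0.0, 0, 100)
+# yields 1.0 at step 0, 0.5 at step 50 and 0.0 for any step >= 100.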
diff --git a/ocl/trees.py b/ocl/trees.py
new file mode 100644
index 0000000..6afb07b
--- /dev/null
+++ b/ocl/trees.py
@@ -0,0 +1,36 @@
+from collections import abc
+from typing import Dict, List, Tuple, Union
+
+Tree = Union[Dict, List, Tuple]
+
+def get_tree_element(d: Tree, path: List[str]):
+ """Get element of a tree."""
+ next_element = d
+
+ for next_element_name in path:
+ if isinstance(next_element, abc.Mapping) and next_element_name in next_element:
+ next_element = next_element[next_element_name]
+ elif hasattr(next_element, next_element_name):
+ next_element = getattr(next_element, next_element_name)
+ elif isinstance(next_element, (list, tuple)) and next_element_name.isnumeric():
+ next_element = next_element[int(next_element_name)]
+ else:
+ try:
+ next_element = getattr(next_element, next_element_name)
+ except AttributeError:
+ msg = f"Trying to access path {'.'.join(path)}, "
+ if isinstance(next_element, abc.Mapping):
+ msg += f"but element {next_element_name} is not among keys {next_element.keys()}"
+ elif isinstance(next_element, (list, tuple)):
+ msg += f"but cannot index into list with {next_element_name}"
+ else:
+ msg += (
+ f"but element {next_element_name} cannot be used to access attribute of "
+ f"object of type {type(next_element)}"
+ )
+ raise ValueError(msg)
+ return next_element
\ No newline at end of file
diff --git a/ocl/utils/__init__.py b/ocl/utils/__init__.py
new file mode 100644
index 0000000..fde6730
--- /dev/null
+++ b/ocl/utils/__init__.py
@@ -0,0 +1,5 @@
+# We added this here to avoid issues in rebasing.
+# Long term the imports should be updated.
+from ocl.utils.windows import JoinWindows
+
+__all__ = ["JoinWindows"]
diff --git a/ocl/utils/annealing.py b/ocl/utils/annealing.py
new file mode 100644
index 0000000..b6bd787
--- /dev/null
+++ b/ocl/utils/annealing.py
@@ -0,0 +1,18 @@
+import math
+
+
+def cosine_anneal_factory(start_value, final_value, start_step, final_step):
+ def cosine_anneal(step):
+ assert start_value >= final_value
+ assert start_step <= final_step
+
+ if step < start_step:
+ value = start_value
+ elif step >= final_step:
+ value = final_value
+ else:
+ a = 0.5 * (start_value - final_value)
+ b = 0.5 * (start_value + final_value)
+ progress = (step - start_step) / (final_step - start_step)
+ value = a * math.cos(math.pi * progress) + b
+
+ return value
+ return cosine_anneal
\ No newline at end of file
diff --git a/ocl/utils/bboxes.py b/ocl/utils/bboxes.py
new file mode 100644
index 0000000..c8f134d
--- /dev/null
+++ b/ocl/utils/bboxes.py
@@ -0,0 +1,14 @@
+"""Utilities for handling bboxes."""
+import torch
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
diff --git a/ocl/utils/logging.py b/ocl/utils/logging.py
new file mode 100644
index 0000000..77cc536
--- /dev/null
+++ b/ocl/utils/logging.py
@@ -0,0 +1,252 @@
+import os
+import signal
+import traceback
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import torch
+import yaml
+from mlflow.client import MlflowClient
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.loggers import MLFlowLogger
+from pytorch_lightning.loggers.base import rank_zero_experiment
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorch_lightning.utilities.model_summary import ModelSummary
+from torch.utils.tensorboard._convert_np import make_np
+from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC, figure_to_image
+from torch.utils.tensorboard.summary import _calc_scale_factor
+
+from ocl.utils.trees import get_tree_element
+
+
+def prepare_video_tensor(tensor):
+ tensor = make_np(tensor)
+ tensor = _prepare_video(tensor)
+ # If user passes in uint8, then we don't need to rescale by 255
+ scale_factor = _calc_scale_factor(tensor)
+ tensor = tensor.astype(np.float32)
+ tensor = (tensor * scale_factor).astype(np.uint8)
+ return tensor
+
+
+def write_video_tensor(prefix, tensor, fps):
+ try:
+ import moviepy # noqa: F401
+ except ImportError:
+ print("add_video needs package moviepy")
+ return
+ try:
+ from moviepy import editor as mpy
+ except ImportError:
+ print(
+ "moviepy is installed, but can't import moviepy.editor.",
+ "Some packages could be missing [imageio, requests]",
+ )
+ return
+
+ # encode sequence of images into gif string
+ clip = mpy.ImageSequenceClip(list(tensor), fps=fps)
+ filename = prefix + ".gif"
+
+    try:  # Newer versions of moviepy use the logger argument instead of progress_bar.
+ clip.write_gif(filename, verbose=False, logger=None)
+ except TypeError:
+        try:  # Older versions of moviepy do not support the progress_bar argument.
+ clip.write_gif(filename, verbose=False, progress_bar=False)
+ except TypeError:
+ clip.write_gif(filename, verbose=False)
+
+ return filename
+
+
+def prepare_image_tensor(tensor, dataformats="NCHW"):
+ tensor = make_np(tensor)
+ tensor = convert_to_HWC(tensor, dataformats)
+ # Do not assume that user passes in values in [0, 255], use data type to detect
+ scale_factor = _calc_scale_factor(tensor)
+ tensor = tensor.astype(np.float32)
+ tensor = (tensor * scale_factor).astype(np.uint8)
+ return tensor
+
+
+def write_image_tensor(prefix: str, tensor: np.ndarray):
+ from PIL import Image
+
+ image = Image.fromarray(tensor)
+ filename = prefix + ".png"
+ image.save(filename, format="png")
+ return filename
+
+
+class ExtendedMLflowExperiment:
+ """MLflow experiment made to mimic tensorboard experiments."""
+
+ def __init__(self, mlflow_client: MlflowClient, run_id: str):
+ self._mlflow_client = mlflow_client
+ self._run_id = run_id
+ self._tempdir = TemporaryDirectory()
+
+ def _get_tmp_prefix_for_step(self, step: int):
+ return os.path.join(self._tempdir.name, f"{step:07d}")
+
+ def add_video(self, vid_tensor, fps: int, tag: str, global_step: int):
+ path = tag # TF paths are typically split using "/"
+ filename = write_video_tensor(
+ self._get_tmp_prefix_for_step(global_step), prepare_video_tensor(vid_tensor), fps
+ )
+ self._mlflow_client.log_artifact(self._run_id, filename, path)
+ os.remove(filename)
+
+ def add_image(self, img_tensor: torch.Tensor, dataformats: str, tag: str, global_step: int):
+ path = tag
+ filename = write_image_tensor(
+ self._get_tmp_prefix_for_step(global_step),
+ prepare_image_tensor(img_tensor, dataformats=dataformats),
+ )
+ self._mlflow_client.log_artifact(self._run_id, filename, path)
+ os.remove(filename)
+
+ def add_images(self, img_tensor, dataformats: str, tag: str, global_step: int):
+ # Internally works by having an additional N dimension in `dataformats`.
+ self.add_image(img_tensor, dataformats, tag, global_step)
+
+ def add_figure(self, figure, close: bool, tag: str, global_step: int):
+ if isinstance(figure, list):
+ self.add_image(
+ figure_to_image(figure, close),
+ dataformats="NCHW",
+ tag=tag,
+ global_step=global_step,
+ )
+ else:
+ self.add_image(
+ figure_to_image(figure, close),
+ dataformats="CHW",
+ tag=tag,
+ global_step=global_step,
+ )
+
+ def __getattr__(self, name):
+ """Fallback to mlflow client for missing attributes.
+
+        Fallback to make the experiment object still behave like the regular MLflow client. While
+        this is suboptimal, it does allow us to save a lot of handcrafted code by relying on
+        inheritance and PyTorch Lightning's implementation of the MLflow logger.
+ """
+ return getattr(self._mlflow_client, name)
+
+
+class ExtendedMLFlowLogger(MLFlowLogger):
+ @property # type: ignore[misc]
+ @rank_zero_experiment
+ def experiment(self) -> ExtendedMLflowExperiment:
+ return ExtendedMLflowExperiment(super().experiment, self._run_id)
+
+ @rank_zero_only
+ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint):
+ self.experiment.log_artifact(
+ self._run_id, checkpoint_callback.best_model_path, "checkpoints"
+ )
+
+ @rank_zero_only
+ def log_artifact(self, local_path, artifact_path=None):
+ self.experiment.log_artifact(self._run_id, local_path, artifact_path=artifact_path)
+
+ @rank_zero_only
+ def log_artifacts(self, local_path, artifact_path=None):
+ self.experiment.log_artifacts(self._run_id, local_path, artifact_path=artifact_path)
+
+
+class LogHydraConfigCallback(Callback):
+    def __init__(
+        self, hydra_output_subdir: str, additional_paths=None, skip_overrides: bool = False
+    ):
+ self.hydra_output_subdir = hydra_output_subdir
+ self.additional_paths = additional_paths
+ self.skip_overrides = skip_overrides
+
+ def _parse_overrides(self):
+ with open(os.path.join(self.hydra_output_subdir, "overrides.yaml"), "r") as f:
+ overrides = yaml.safe_load(f)
+
+ output = {}
+ for override in overrides:
+ fragments = override.split("=")
+ if len(fragments) == 2:
+ if override.startswith("+"):
+ fragments[0] = fragments[0][1:]
+
+ output[fragments[0]] = fragments[1]
+ return output
+
+ def _parse_additional_paths(self):
+ with open(os.path.join(self.hydra_output_subdir, "config.yaml"), "r") as f:
+ config = yaml.safe_load(f)
+
+ outputs = {}
+ if isinstance(self.additional_paths, dict):
+ for output_path, input_path in self.additional_paths.items():
+ outputs[output_path] = get_tree_element(config, input_path.split("."))
+ elif isinstance(self.additional_paths, list):
+ for additional_path in self.additional_paths:
+ outputs[additional_path] = get_tree_element(config, additional_path.split("."))
+ else:
+ raise ValueError("additional_paths of unsupported format")
+
+ return outputs
+
+ @rank_zero_only
+ def on_train_start(self, trainer, pl_module):
+ # Log all hydra config files.
+ trainer.logger.log_artifacts(self.hydra_output_subdir, "config")
+ if not self.skip_overrides:
+ trainer.logger.log_hyperparams(self._parse_overrides())
+ if self.additional_paths:
+ trainer.logger.log_hyperparams(self._parse_additional_paths())
+
+ @rank_zero_only
+ def on_exception(self, trainer, pl_module, exception):
+ del pl_module
+ logger = trainer.logger
+ with TemporaryDirectory() as d:
+ filename = os.path.join(d, "exception.txt")
+ with open(filename, "w") as f:
+ traceback.print_exc(file=f)
+ f.flush()
+ trainer.logger.log_artifact(filename)
+ os.remove(filename)
+ if logger.experiment.get_run(logger.run_id):
+ if isinstance(exception, KeyboardInterrupt):
+ logger.experiment.set_terminated(logger.run_id, status="KILLED")
+ else:
+ logger.experiment.set_terminated(logger.run_id, status="FAILED")
+
+ def on_fit_start(self, trainer, pl_module):
+ del pl_module
+ # Register our own signal handler to set run to terminated.
+ previous_sigterm_handler = signal.getsignal(signal.SIGTERM)
+
+ def handler(signum, frame):
+ rank_zero_info("Handling SIGTERM")
+ logger = trainer.logger
+ logger.experiment.set_terminated(logger.run_id, status="KILLED")
+ logger.save()
+ if previous_sigterm_handler in [None, signal.SIG_DFL]:
+ # Not set up by python or default behaviour.
+ signal.signal(signal.SIGTERM, signal.SIG_DFL)
+ signal.raise_signal(signal.SIGTERM)
+            elif previous_sigterm_handler != signal.SIG_IGN:
+                # Otherwise the previous handler is a regular callable; forward the signal to it.
+                previous_sigterm_handler(signum, frame)
+
+ signal.signal(signal.SIGTERM, handler)
+
+
+class LogModelSummaryCallback(Callback):
+ @rank_zero_only
+ def on_fit_start(self, trainer, pl_module):
+ with TemporaryDirectory() as d:
+ filename = os.path.join(d, "model_summary.txt")
+ with open(filename, "w") as f:
+ f.write(str(ModelSummary(pl_module, max_depth=-1)))
+ trainer.logger.log_artifact(filename)
+ os.remove(filename)
diff --git a/ocl/utils/masking.py b/ocl/utils/masking.py
new file mode 100644
index 0000000..cebf8ef
--- /dev/null
+++ b/ocl/utils/masking.py
@@ -0,0 +1,75 @@
+"""Utilities related to masking."""
+import math
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ocl.utils.routing import RoutableMixin
+
+
+class CreateSlotMask(nn.Module, RoutableMixin):
+ """Module intended to create a mask that marks empty slots.
+
+ Module takes a tensor holding the number of slots per batch entry, and returns a binary mask of
+ shape (batch_size, max_slots) where entries exceeding the number of slots are masked out.
+ """
+
+ def __init__(self, max_slots: int, n_slots_path: str):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"n_slots": n_slots_path})
+ self.max_slots = max_slots
+
+ @RoutableMixin.route
+ def forward(self, n_slots: torch.Tensor) -> torch.Tensor:
+ (batch_size,) = n_slots.shape
+
+ # Create mask of shape B x K where the first n_slots entries per-row are false, the rest true
+ indices = torch.arange(self.max_slots, device=n_slots.device)
+ masks = indices.unsqueeze(0).expand(batch_size, -1) >= n_slots.unsqueeze(1)
+
+ return masks
+
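+# Expected behaviour sketch (assumed values): with max_slots=4 and n_slots=torch.tensor([2, 3]),
+# CreateSlotMask returns
+#   [[False, False, True,  True],
+#    [False, False, False, True]]
+# i.e. True marks the slots beyond each sample's slot count.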
+
+class CreateRandomMaskPatterns(nn.Module, RoutableMixin):
+ """Create random masks.
+
+ Useful for showcasing behavior of metrics.
+ """
+
+ def __init__(
+ self, pattern: str, masks_path: str, n_slots: Optional[int] = None, n_cols: int = 2
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"masks": masks_path})
+ if pattern not in ("random", "blocks"):
+ raise ValueError(f"Unknown pattern {pattern}")
+ self.pattern = pattern
+ self.n_slots = n_slots
+ self.n_cols = n_cols
+
+ @RoutableMixin.route
+ def forward(self, masks: torch.Tensor) -> torch.Tensor:
+ if self.pattern == "random":
+ rand_mask = torch.rand_like(masks)
+ return rand_mask / rand_mask.sum(1, keepdim=True)
+ elif self.pattern == "blocks":
+ n_slots = masks.shape[1] if self.n_slots is None else self.n_slots
+ height, width = masks.shape[-2:]
+ new_masks = torch.zeros(
+ len(masks), n_slots, height, width, device=masks.device, dtype=masks.dtype
+ )
+ blocks_per_col = int(n_slots // self.n_cols)
+ remainder = n_slots - (blocks_per_col * self.n_cols)
+ slot = 0
+ for col in range(self.n_cols):
+ rows = blocks_per_col if col < self.n_cols - 1 else blocks_per_col + remainder
+ for row in range(rows):
+ block_width = math.ceil(width / self.n_cols)
+ block_height = math.ceil(height / rows)
+ x = col * block_width
+ y = row * block_height
+ new_masks[:, slot, y : y + block_height, x : x + block_width] = 1
+ slot += 1
+ assert torch.allclose(new_masks.sum(1), torch.ones_like(masks[:, 0]))
+ return new_masks
diff --git a/ocl/utils/resizing.py b/ocl/utils/resizing.py
new file mode 100644
index 0000000..f5eb6cb
--- /dev/null
+++ b/ocl/utils/resizing.py
@@ -0,0 +1,147 @@
+"""Utilities related to resizing of tensors."""
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ocl.utils.routing import RoutableMixin
+
+
+class Resize(nn.Module, RoutableMixin):
+ """Module resizing tensors."""
+
+ MODES = {"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"}
+
+ def __init__(
+ self,
+ input_path: str,
+ size: Optional[Union[int, Tuple[int, int]]] = None,
+ take_size_from: Optional[str] = None,
+ resize_mode: str = "bilinear",
+ patch_mode: bool = False,
+ channels_last: bool = False,
+ ):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"tensor": input_path, "size_tensor": take_size_from})
+
+ if size is not None and take_size_from is not None:
+ raise ValueError("`size` and `take_size_from` can not be set at the same time")
+ self.size = size
+
+ if resize_mode not in Resize.MODES:
+ raise ValueError(f"`mode` must be one of {Resize.MODES}")
+ self.resize_mode = resize_mode
+ self.patch_mode = patch_mode
+ self.channels_last = channels_last
+ self.expected_dims = 3 if patch_mode else 4
+
+ @RoutableMixin.route
+ def forward(
+ self, tensor: torch.Tensor, size_tensor: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Resize tensor.
+
+ Args:
+ tensor: Tensor to resize. If `patch_mode=False`, assumed to be of shape (..., C, H, W).
+ If `patch_mode=True`, assumed to be of shape (..., C, P), where P is the number of
+ patches. Patches are assumed to be viewable as a perfect square image. If
+ `channels_last=True`, channel dimension is assumed to be the last dimension instead.
+ size_tensor: Tensor which size to resize to. If tensor has <=2 dimensions and the last
+ dimension of this tensor has length 2, the two entries are taken as height and width.
+ Otherwise, the size of the last two dimensions of this tensor are used as height
+ and width.
+
+ Returns: Tensor of shape (..., C, H, W), where height and width are either specified by
+ `size` or `size_tensor`.
+ """
+ dims_to_flatten = tensor.ndim - self.expected_dims
+ if dims_to_flatten > 0:
+ flattened_dims = tensor.shape[: dims_to_flatten + 1]
+ tensor = tensor.flatten(0, dims_to_flatten)
+ elif dims_to_flatten < 0:
+ raise ValueError(
+ f"Tensor needs at least {self.expected_dims} dimensions, but only has {tensor.ndim}"
+ )
+
+ if self.patch_mode:
+ if self.channels_last:
+ tensor = tensor.transpose(-2, -1)
+ n_channels, n_patches = tensor.shape[-2:]
+ patch_size_float = math.sqrt(n_patches)
+ patch_size = int(math.sqrt(n_patches))
+ if patch_size_float != patch_size:
+ raise ValueError(
+ f"The number of patches needs to be a perfect square, but is {n_patches}."
+ )
+ tensor = tensor.view(-1, n_channels, patch_size, patch_size)
+ else:
+ if self.channels_last:
+ tensor = tensor.permute(0, 3, 1, 2)
+
+ if self.size is None:
+ if size_tensor is None:
+ raise ValueError("`size` is `None` but no `size_tensor` was passed.")
+ if size_tensor.ndim <= 2 and size_tensor.shape[-1] == 2:
+ height, width = size_tensor.unbind(-1)
+ height = torch.atleast_1d(height)[0].squeeze().detach().cpu()
+ width = torch.atleast_1d(width)[0].squeeze().detach().cpu()
+ size = (int(height), int(width))
+ else:
+ size = size_tensor.shape[-2:]
+ else:
+ size = self.size
+
+ tensor = torch.nn.functional.interpolate(
+ tensor,
+ size=size,
+ mode=self.resize_mode,
+ )
+
+ if dims_to_flatten > 0:
+ tensor = tensor.unflatten(0, flattened_dims)
+
+ return tensor
+
+
+def resize_patches_to_image(
+ patches: torch.Tensor,
+ size: Optional[int] = None,
+ scale_factor: Optional[float] = None,
+ resize_mode: str = "bilinear",
+) -> torch.Tensor:
+ """Convert and resize a tensor of patches to image shape.
+
+ This method requires that the patches can be converted to a square image.
+
+ Args:
+ patches: Patches to be converted of shape (..., C, P), where C is the number of channels and
+ P the number of patches.
+ size: Image size to resize to.
+ scale_factor: Scale factor by which to resize the patches. Can be specified alternatively to
+ `size`.
+ resize_mode: Method to resize with. Valid options are "nearest", "nearest-exact", "bilinear",
+ "bicubic".
+
+ Returns: Tensor of shape (..., C, S, S) where S is the image size.
+ """
+    has_size = size is not None
+    has_scale = scale_factor is not None
+    if has_size == has_scale:
+        raise ValueError("Exactly one of `size` or `scale_factor` must be specified.")
+
+ n_channels = patches.shape[-2]
+ n_patches = patches.shape[-1]
+ patch_size_float = math.sqrt(n_patches)
+ patch_size = int(math.sqrt(n_patches))
+ if patch_size_float != patch_size:
+ raise ValueError("The number of patches needs to be a perfect square.")
+
+ image = torch.nn.functional.interpolate(
+ patches.view(-1, n_channels, patch_size, patch_size),
+ size=size,
+ scale_factor=scale_factor,
+ mode=resize_mode,
+ )
+
+ return image.view(*patches.shape[:-1], image.shape[-2], image.shape[-1])
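+
+
+# Example (assumed shapes): a ViT-style feature map with 196 patches can be turned into a
+# 224x224 map via
+#   resize_patches_to_image(torch.randn(2, 64, 196), size=224)  # -> shape (2, 64, 224, 224)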
diff --git a/ocl/utils/routing.py b/ocl/utils/routing.py
new file mode 100644
index 0000000..baf2f3d
--- /dev/null
+++ b/ocl/utils/routing.py
@@ -0,0 +1,251 @@
+"""Utility function related to routing of information.
+
+These utility functions allow dynamical routing between modules and allow the specification of
+complex models using config alone.
+"""
+from __future__ import annotations
+
+import functools
+import inspect
+from typing import Any, Dict, List, Mapping, Optional, Union
+
+import torch
+from torch import nn
+
+import ocl.utils.trees as tree_utils
+
+
+class RoutableMixin:
+ """Mixin class that allows to connect any element of a (nested) dict with a module input."""
+
+ def __init__(self, input_mapping: Mapping[str, Optional[str]]):
+ self.input_mapping = {
+ key: value.split(".") for key, value in input_mapping.items() if value is not None
+ }
+
+ def _route(method, filter_parameters=True):
+ """Pass arguments to a function based on the mapping defined in `self.input_mapping`.
+
+ This method supports both filtering for parameters that match the arguments of the wrapped
+ method and passing all arguments defined in `input_mapping`. If a non-optional argument is
+ missing this will raise an exception. Additional arguments can also be passed to the method
+ to override entries in the input dict. Non-keyword arguments are always directly passed to
+ the method.
+
+ Args:
+ method: The method to pass the arguments to.
+ filter_parameters: Only pass arguments to wrapped method that match the methods
+ signature. This is practical if different methods require different types of input.
+
+ """
+ # Run inspection here to reduce compute time when calling method.
+ signature = inspect.signature(method)
+ valid_parameters = list(signature.parameters) # Returns the parameter names.
+ valid_parameters = valid_parameters[1:] # Discard "self".
+ # Keep track of default parameters. For these we should not fail if they are not in
+ # the input dict.
+ with_defaults = [
+ name
+ for name, param in signature.parameters.items()
+ if param.default is not inspect.Parameter.empty
+ ]
+
+ @functools.wraps(method)
+ def method_with_routing(self, *args, inputs=None, **kwargs):
+ if not inputs:
+ inputs = {}
+ if self.input_mapping:
+ if not inputs: # Empty dict.
+ inputs = kwargs
+
+ routed_inputs = {}
+ for input_field, input_path in self.input_mapping.items():
+ if filter_parameters and input_field not in valid_parameters:
+ # Skip parameters that are not the function signature.
+ continue
+ if input_field in kwargs.keys():
+ # Skip parameters that are directly provided as kwargs.
+ continue
+ try:
+ element = tree_utils.get_tree_element(inputs, input_path)
+ routed_inputs[input_field] = element
+ except ValueError as e:
+ if input_field in with_defaults:
+ continue
+ else:
+ raise e
+ # Support for additional parameters passed via keyword arguments.
+ # TODO(hornmax): This is not ideal as it mixes routing args from the input dict
+ # and explicitly passed kwargs and thus could lead to collisions.
+ for name, element in kwargs.items():
+ if filter_parameters and name not in valid_parameters:
+ continue
+ else:
+ routed_inputs[name] = element
+ return method(self, *args, **routed_inputs)
+ else:
+ return method(self, *args, **kwargs)
+
+ return method_with_routing
+
+ # This is needed in order to allow the decorator to be used in child classes. The documentation
+ # looks a bit hacky but I didn't find an alternative approach on how to do it.
+ route = staticmethod(functools.partial(_route, filter_parameters=True))
+ route.__doc__ = (
+ """Route input arguments according to input_mapping and filter non-matching arguments."""
+ )
+ route_unfiltered = staticmethod(functools.partial(_route, filter_parameters=False))
+ route_unfiltered.__doc__ = """Route all input arguments according to input_mapping."""
+
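+# Minimal usage sketch (hypothetical `Describe` module, not part of the library): a subclass
+# declares its input mapping and decorates `forward` with `RoutableMixin.route`, after which
+# arguments are resolved from a nested input dict:
+#
+#   class Describe(nn.Module, RoutableMixin):
+#       def __init__(self, image_path: str):
+#           nn.Module.__init__(self)
+#           RoutableMixin.__init__(self, {"image": image_path})
+#
+#       @RoutableMixin.route
+#       def forward(self, image: torch.Tensor) -> torch.Size:
+#           return image.shape
+#
+#   Describe("input.image")(inputs={"input": {"image": torch.zeros(1, 3, 32, 32)}})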
+
+class DataRouter(nn.Module, RoutableMixin):
+ """Data router for modules that don't support the RoutableMixin.
+
+ This allows the usage of modules without RoutableMixin support in the dynamic information flow
+ pattern of the code.
+ """
+
+ def __init__(self, module: nn.Module, input_mapping: Mapping[str, str]):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, input_mapping)
+ self.module = module
+ self._cached_valid_parameters = None
+
+ @RoutableMixin.route_unfiltered
+ def forward(self, *args, **kwargs):
+ # We need to filter parameters at runtime as we cannot know them prior to initialization.
+ if not self._cached_valid_parameters:
+ try:
+ signature = inspect.signature(self.module.forward)
+ except AttributeError:
+ if callable(self.module):
+ signature = inspect.signature(self.module.__call__)
+ else:
+ signature = inspect.signature(self.module)
+
+ self._cached_valid_parameters = list(signature.parameters)
+
+ kwargs = {
+ name: param for name, param in kwargs.items() if name in self._cached_valid_parameters
+ }
+ return self.module(*args, **kwargs)
+
+
+class Combined(nn.ModuleDict):
+ """Module to combine multiple modules and store their outputs.
+
+ A combined module groups together multiple model components and allows them to access any
+ information that was returned in processing steps prior to their own application.
+
+ It functions similarly to `nn.ModuleDict` yet for modules of type `RoutableMixin` and
+ additionally implements a forward routine which will return a dict of the outputs of the
+ submodules.
+
+ """
+
+ def __init__(self, modules: Dict[str, Union[RoutableMixin, Combined, Recurrent]]):
+ super().__init__(modules)
+
+ def forward(self, inputs: Dict[str, Any]):
+ # The combined module does not know where it is positioned and thus also does not know in
+        # which sub-path results should be written. As we want different modules of a combined
+        # module to be able to access previous outputs using their global path in the dictionary,
+        # we need to somehow keep track of the nesting level and then directly write results into
+        # the input dict at the right path. The prefix variable keeps track of the nesting level.
+ prefix: List[str]
+ if "prefix" in inputs.keys():
+ prefix = inputs["prefix"]
+ else:
+ prefix = []
+ inputs["prefix"] = prefix
+
+ outputs = tree_utils.get_tree_element(inputs, prefix)
+ for name, module in self.items():
+ # Update prefix state such that nested calls of combined return dict in the correct
+ # location.
+ prefix.append(name)
+ outputs[name] = {}
+ # If module is a Combined module, it will return the same dict as set above. If not the
+ # dict will be overwritten with the output of the module.
+ outputs[name] = module(inputs=inputs)
+ # Remove last component of prefix after execution.
+ prefix.pop()
+ return outputs
+
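+# Usage sketch (hypothetical component names): later entries can route to the outputs of
+# earlier ones via their global path in the dict, e.g.
+#   model = Combined({
+#       "feature_extractor": DataRouter(backbone, {"images": "input.image"}),
+#       "head": routable_head,  # its input_mapping may reference "feature_extractor.*"
+#   })
+#   outputs = model(inputs={"input": batch})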
+
+class Recurrent(nn.Module):
+    """Module to apply another module in a recurrent fashion over an axis.
+
+    This module takes a set of input tensors and applies a module recurrently over them. The
+    output of the previous iteration is kept in the `previous_output` key of the input dict and
+    thus can be accessed using data routing. After applying the module to the input slices, the
+    outputs are stacked along the same axis as the inputs were split.
+
+ Args:
+ module: The module that should be applied recurrently along input tensors.
+ inputs_to_split: List of paths that should be split for recurrent application.
+ initial_input_mapping: Mapping that constructs the first `previous_output` element. If
+ `previous_output` should just be a tensor, use a mapping of the format
+ `{"": "input_path"}`.
+ split_axis: Axis along which to split the tensors defined by inputs_to_split.
+ chunk_size: The size of each slice, when set to 1, the slice dimension is squeezed prior to
+ passing to the module.
+
+ """
+
+ def __init__(
+ self,
+ module,
+ inputs_to_split: List[str],
+ initial_input_mapping: Dict[str, str],
+ split_axis: int = 1,
+ chunk_size: int = 1,
+ ):
+ super().__init__()
+ self.module = module
+ self.inputs_to_split = [path.split(".") for path in inputs_to_split]
+ self.initial_input_mapping = {
+ output: input.split(".") for output, input in initial_input_mapping.items()
+ }
+ self.split_axis = split_axis
+ self.chunk_size = chunk_size
+
+ def _build_initial_dict(self, inputs):
+        # This allows us to bring the initial input and previous_output into a similar format.
+ output_dict = {}
+ for output_path, input_path in self.initial_input_mapping.items():
+ source = tree_utils.get_tree_element(inputs, input_path)
+ if output_path == "":
+ # Just the object itself, no dict nesting.
+ return source
+
+ output_path = output_path.split(".")
+ cur_search = output_dict
+ for path_part in output_path[:-1]:
+ # Iterate along path and create nodes that do not exist yet.
+ try:
+ # Get element prior to last.
+ cur_search = tree_utils.get_tree_element(cur_search, [path_part])
+ except ValueError:
+ # Element does not yet exist.
+ cur_search[path_part] = {}
+ cur_search = cur_search[path_part]
+
+ cur_search[output_path[-1]] = source
+ return output_dict
+
+ def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ # TODO: Come up with a better way of handling the initial input without putting restrictions
+ # on modules being run recurrently.
+ outputs = [self._build_initial_dict(inputs)]
+ for split_dict in tree_utils.split_tree(
+ inputs, self.inputs_to_split, self.split_axis, self.chunk_size
+ ):
+ split_dict["previous_output"] = outputs[-1]
+ outputs.append(self.module(inputs=split_dict))
+
+        # TODO: When chunk size is larger than 1 this should be cat and not stack, otherwise an
+        # additional axis would be added. Possibly this should be configurable.
+ stack_fn = functools.partial(torch.stack, dim=self.split_axis)
+ # Ignore initial input.
+ return tree_utils.reduce_tree(outputs[1:], stack_fn)
diff --git a/ocl/utils/trees.py b/ocl/utils/trees.py
new file mode 100644
index 0000000..fff0f8c
--- /dev/null
+++ b/ocl/utils/trees.py
@@ -0,0 +1,158 @@
+"""Utilities for working with our own version of PyTrees which focus on torch tensors.
+
+PyTrees are any nested structure of dictionaries, lists, tuples, namedtuples or dataclasses.
+"""
+import copy
+import dataclasses
+from collections import OrderedDict, abc
+from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
+
+import torch
+
+Tree = Union[Dict, List, Tuple]
+
+
+def is_tensor_or_module(t: Any):
+ """Check if input is a torch.Tensor or a torch.nn.Module."""
+ return isinstance(t, (torch.Tensor, torch.nn.Module))
+
+
+def is_namedtuple(obj) -> bool:
+ """Check if input is a named tuple."""
+ return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+
+
+def get_tree_element(d: Tree, path: List[str]):
+ """Get element of a tree."""
+ next_element = d
+
+ for next_element_name in path:
+ if isinstance(next_element, abc.Mapping) and next_element_name in next_element:
+ next_element = next_element[next_element_name]
+ elif hasattr(next_element, next_element_name):
+ next_element = getattr(next_element, next_element_name)
+ elif isinstance(next_element, (list, tuple)) and next_element_name.isnumeric():
+ next_element = next_element[int(next_element_name)]
+ else:
+ try:
+ next_element = getattr(next_element, next_element_name)
+ except AttributeError:
+ msg = f"Trying to access path {'.'.join(path)}, "
+ if isinstance(next_element, abc.Mapping):
+ msg += f"but element {next_element_name} is not among keys {next_element.keys()}"
+ elif isinstance(next_element, (list, tuple)):
+ msg += f"but cannot index into list with {next_element_name}"
+ else:
+ msg += (
+ f"but element {next_element_name} cannot be used to access attribute of "
+ f"object of type {type(next_element)}"
+ )
+ raise ValueError(msg)
+ return next_element
+
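+# Illustrative example (assumed data): get_tree_element({"a": {"b": [1, 2]}}, ["a", "b", "1"])
+# returns 2; list and tuple entries are addressed with numeric strings.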
+
+def _build_walk_path(previous_element, new_element):
+ return previous_element + [new_element]
+
+
+def walk_tree_with_paths(next_element, path=None, instance_check=is_tensor_or_module):
+    """Walk over all tensors + modules and their paths in a nested structure.
+
+    Note that this can lead to an infinite loop if the structure contains reference cycles.
+    """
+ if path is None:
+ path = []
+
+ if instance_check(next_element):
+ yield path, next_element
+ elif isinstance(next_element, str):
+ # Special handling for strings, as even a single element slice is a sequence. This leads to
+ # infinite nesting.
+ pass
+ elif isinstance(next_element, (dict, Mapping)):
+ for key, value in next_element.items():
+ yield from walk_tree_with_paths(
+ value, path=_build_walk_path(path, key), instance_check=instance_check
+ )
+ elif dataclasses.is_dataclass(next_element):
+ for field in dataclasses.fields(next_element):
+ yield from walk_tree_with_paths(
+ getattr(next_element, field.name),
+ path=_build_walk_path(path, field.name),
+ instance_check=instance_check,
+ )
+ elif is_namedtuple(next_element):
+ for field_name in next_element._fields:
+ yield from walk_tree_with_paths(
+ getattr(next_element, field_name),
+ path=_build_walk_path(path, field_name),
+ instance_check=instance_check,
+ )
+ elif isinstance(next_element, (List, Sequence, tuple)):
+ for index, el in enumerate(next_element):
+ yield from walk_tree_with_paths(
+ el, path=_build_walk_path(path, index), instance_check=instance_check
+ )
+
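+# Illustrative example (assumed data): walking {"a": torch.zeros(1), "b": [torch.ones(2)]}
+# yields the path/value pairs (["a"], tensor) and (["b", 0], tensor).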
+
+def reduce_tree(outputs: List[Dict[str, Any]], fn: Callable[[List[torch.Tensor]], torch.Tensor]):
+ """Apply reduction function to a list of nested dicts.
+
+    This only considers tensors at the moment; other data types are simply copied from the first
+    element.
+ """
+ id_to_reduced_tensor = {}
+ for path, tensor in walk_tree_with_paths(outputs[0]):
+ stacked_tensor = fn([tensor] + [get_tree_element(output, path) for output in outputs[1:]])
+ id_to_reduced_tensor[id(tensor)] = stacked_tensor
+
+ # Replace all tensors with their stacked versions.
+ return copy.deepcopy(outputs[0], memo=id_to_reduced_tensor)
+
+
+def map_tree(d: Tree, fn: Callable[[torch.Tensor], torch.Tensor]):
+ """Apply a function to each element of a tree.
+
+    This only considers tensors at the moment; other data types are simply copied from the
+    input tree.
+ """
+ id_to_mapped_tensor = {}
+ for _, tensor in walk_tree_with_paths(d):
+ mapped_tensor = fn(tensor)
+ id_to_mapped_tensor[id(tensor)] = mapped_tensor
+
+    # Replace all tensors with their mapped versions.
+ return copy.deepcopy(d, memo=id_to_mapped_tensor)
+
+
+def split_tree(d: Tree, split_paths: List[List[str]], split_axis: int, chunk_size: int):
+    # We essentially need a deep copy of the input dict that we then update with split
+    # references. To avoid copies of tensors and thus memory duplication we want to use shallow
+    # copies for tensors instead. We do this by defining the memo parameter used in deepcopy for
+    # all tensors in the dict. This way deepcopy thinks that these were already copied and uses
+    # the provided objects instead. We can further use this trick to replace the original
+    # tensors with split counterparts when running deepcopy.
+
+ # Create memo containing all tensors to avoid data duplication.
+ memo = {id(tensor): tensor for path, tensor in walk_tree_with_paths(d)}
+
+ # Gather tensors that should be replaced and note their id.
+ tensors_to_split = [get_tree_element(d, path) for path in split_paths]
+ splitted_memos = OrderedDict(
+ (id(tensor), torch.split(tensor, chunk_size, dim=split_axis)) for tensor in tensors_to_split
+ )
+
+ for tensor_slices in zip(*splitted_memos.values()):
+        # Replace entries in the memo dict with split counterparts.
+ if chunk_size == 1:
+ # Additionally squeeze the input.
+ memo_override = {
+ orig_id: tensor_slice.squeeze(split_axis)
+ for orig_id, tensor_slice in zip(splitted_memos.keys(), tensor_slices)
+ }
+ else:
+ memo_override = {
+ orig_id: tensor_slice
+ for orig_id, tensor_slice in zip(splitted_memos.keys(), tensor_slices)
+ }
+ yield copy.deepcopy(d, {**memo, **memo_override})
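+
+
+# Illustrative behaviour (assumed shapes): split_tree({"video": torch.randn(2, 5, 3, 64, 64)},
+# [["video"]], split_axis=1, chunk_size=1) yields five dicts whose "video" entries have shape
+# (2, 3, 64, 64); tensors that are not split are shared rather than copied.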
diff --git a/ocl/utils/windows.py b/ocl/utils/windows.py
new file mode 100644
index 0000000..74cccf7
--- /dev/null
+++ b/ocl/utils/windows.py
@@ -0,0 +1,44 @@
+from typing import List
+
+import torch
+from torch import nn
+
+from ocl.utils.routing import RoutableMixin
+
+
+class JoinWindows(nn.Module, RoutableMixin):
+ def __init__(self, n_windows: int, size, masks_path: str, key_path: str = "input.__key__"):
+ nn.Module.__init__(self)
+ RoutableMixin.__init__(self, {"masks": masks_path, "keys": key_path})
+ self.n_windows = n_windows
+ self.size = size
+
+ @RoutableMixin.route
+    def forward(self, masks: torch.Tensor, keys: List[str]) -> torch.Tensor:
+ assert len(masks) == self.n_windows
+ keys_split = [key.split("_") for key in keys]
+ pad_left = [int(elems[1]) for elems in keys_split]
+ pad_top = [int(elems[2]) for elems in keys_split]
+
+ target_height, target_width = self.size
+ n_masks = masks.shape[0] * masks.shape[1]
+ height, width = masks.shape[2], masks.shape[3]
+ full_mask = torch.zeros(n_masks, *self.size).to(masks)
+ x = 0
+ y = 0
+ for idx, mask in enumerate(masks):
+ elems = masks.shape[1]
+ x_start = 0 if pad_left[idx] >= 0 else -pad_left[idx]
+ x_end = min(width, target_width - pad_left[idx])
+ y_start = 0 if pad_top[idx] >= 0 else -pad_top[idx]
+ y_end = min(height, target_height - pad_top[idx])
+ cropped = mask[:, y_start:y_end, x_start:x_end]
+ full_mask[
+ idx * elems : (idx + 1) * elems, y : y + cropped.shape[-2], x : x + cropped.shape[-1]
+ ] = cropped
+ x += cropped.shape[-1]
+ if x > target_width:
+ y += cropped.shape[-2]
+ x = 0
+
+ assert torch.all(torch.abs(torch.sum(full_mask, axis=0) - 1) <= 1e-2)
+
+ return full_mask.unsqueeze(0)
diff --git a/ocl/visualization_types.py b/ocl/visualization_types.py
new file mode 100644
index 0000000..730326a
--- /dev/null
+++ b/ocl/visualization_types.py
@@ -0,0 +1,74 @@
+"""Classes for handling different types of visualizations."""
+import dataclasses
+from typing import Any, List, Optional, Union
+
+import matplotlib.pyplot
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from torchtyping import TensorType
+
+
+def dataclass_to_dict(d):
+ return {field.name: getattr(d, field.name) for field in dataclasses.fields(d)}
+
+
+@dataclasses.dataclass
+class Visualization:
+ def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
+ pass
+
+
+@dataclasses.dataclass
+class Figure(Visualization):
+ """Matplotlib figure."""
+
+ figure: matplotlib.pyplot.figure
+ close: bool = True
+
+ def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
+ experiment.add_figure(**dataclass_to_dict(self), tag=tag, global_step=global_step)
+
+
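+
+
+# Quick sanity example: box_cxcywh_to_xyxy(torch.tensor([0.5, 0.5, 1.0, 1.0])) gives
+# tensor([0., 0., 1., 1.]), and box_xyxy_to_cxcywh maps it back.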
+@dataclasses.dataclass
+class Image(Visualization):
+ """Single image."""
+
+ img_tensor: torch.Tensor
+ dataformats: str = "CHW"
+
+ def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
+ experiment.add_image(**dataclass_to_dict(self), tag=tag, global_step=global_step)
+
+
+@dataclasses.dataclass
+class Images(Visualization):
+ """Batch of images."""
+
+ img_tensor: torch.Tensor
+ dataformats: str = "NCHW"
+
+ def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
+ experiment.add_images(**dataclass_to_dict(self), tag=tag, global_step=global_step)
+
+
+@dataclasses.dataclass
+class Video(Visualization):
+ """Batch of videos."""
+
+ vid_tensor: TensorType["batch_size", "frames", "channels", "height", "width"] # noqa: F821
+ fps: Union[int, float] = 4
+
+ def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
+ experiment.add_video(**dataclass_to_dict(self), tag=tag, global_step=global_step)
+
+
+@dataclasses.dataclass
+class Embedding(Visualization):
+ """Batch of embeddings."""
+
+ mat: TensorType["batch_size", "feature_dim"] # noqa: F821
+ metadata: Optional[List[Any]] = None
+ label_img: Optional[TensorType["batch_size", "channels", "height", "width"]] = None # noqa: F821
+ metadata_header: Optional[List[str]] = None
+
+ def add_to_experiment(self, experiment: SummaryWriter, tag: str, global_step: int):
+ experiment.add_embedding(**dataclass_to_dict(self), tag=tag, global_step=global_step)
diff --git a/ocl/visualizations.py b/ocl/visualizations.py
new file mode 100644
index 0000000..3ca9e4f
--- /dev/null
+++ b/ocl/visualizations.py
@@ -0,0 +1,675 @@
+from typing import Callable, Dict, List, Optional, Tuple
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from torchvision import transforms
+from torchvision.ops import masks_to_boxes
+from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks, make_grid
+
+from ocl import consistency, visualization_types
+from ocl.utils.routing import RoutableMixin
+
+
+def _flow_tensor_to_rgb_tensor(flow, flow_scaling_factor=50.0):
+ """Visualizes flow motion image as an RGB image.
+
+ Adapted from github.com/google-research/slot-attention-video/blob/main/savi/lib/preprocessing.py
+
+ Args:
+ flow: A tensor either of shape [..., 2, height, width].
+ flow_scaling_factor: How much to scale flow for visualization.
+
+ Returns:
+ A visualization tensor with the same shape as flow, except with three channels.
+
+ """
+ hypot = lambda a, b: (a**2.0 + b**2.0) ** 0.5 # sqrt(a^2 + b^2)
+ flow = torch.moveaxis(flow, -3, -1)
+ height, width = flow.shape[-3:-1]
+ scaling = flow_scaling_factor / hypot(height, width)
+ x, y = flow[..., 0], flow[..., 1]
+ motion_angle = torch.atan2(y, x)
+    motion_angle = (motion_angle / np.pi + 1.0) / 2.0
+ motion_magnitude = hypot(y, x)
+ motion_magnitude = torch.clip(motion_magnitude * scaling, 0.0, 1.0)
+ value_channel = torch.ones_like(motion_angle)
+ flow_hsv = torch.stack([motion_angle, motion_magnitude, value_channel], dim=-1)
+ flow_rbg = matplotlib.colors.hsv_to_rgb(flow_hsv.detach().numpy())
+ flow_rbg = torch.moveaxis(torch.Tensor(flow_rbg), -1, -3)
+ return flow_rbg
+
+
+def _nop(arg):
+ return arg
+
+
+class Image(RoutableMixin):
+ def __init__(
+ self,
+ n_instances: int = 8,
+ n_row: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ as_grid: bool = True,
+ image_path: Optional[str] = None,
+ ):
+ super().__init__({"image": image_path})
+ self.n_instances = n_instances
+ self.n_row = n_row
+ self.denormalization = denormalization if denormalization else _nop
+ self.as_grid = as_grid
+
+ @RoutableMixin.route
+ def __call__(self, image: torch.Tensor):
+ image = self.denormalization(image[: self.n_instances].cpu())
+ if self.as_grid:
+ return visualization_types.Image(make_grid(image, nrow=self.n_row))
+ else:
+ return visualization_types.Images(image)
+
+
+class Video(RoutableMixin):
+ def __init__(
+ self,
+ n_instances: int = 8,
+ n_row: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ as_grid: bool = True,
+ video_path: Optional[str] = None,
+ fps: int = 10,
+ ):
+ super().__init__({"video": video_path})
+ self.n_instances = n_instances
+ self.n_row = n_row
+ self.denormalization = denormalization if denormalization else _nop
+ self.as_grid = as_grid
+ self.fps = fps
+
+ @RoutableMixin.route
+ def __call__(self, video: torch.Tensor):
+ video = video[: self.n_instances].cpu()
+ if self.as_grid:
+ video = torch.stack(
+ [
+ make_grid(self.denormalization(frame.unsqueeze(1)).squeeze(1), nrow=self.n_row)
+ for frame in torch.unbind(video, 1)
+ ],
+ dim=0,
+ ).unsqueeze(0)
+ return visualization_types.Video(video, fps=self.fps)
+
+
+class Mask(RoutableMixin):
+ def __init__(
+ self,
+ n_instances: int = 8,
+ mask_path: Optional[str] = None,
+ fps: int = 10,
+ ):
+ super().__init__({"masks": mask_path})
+ self.n_instances = n_instances
+ self.fps = fps
+
+ @RoutableMixin.route
+ def __call__(self, masks):
+ masks = masks[: self.n_instances].cpu().contiguous()
+ image_shape = masks.shape[-2:]
+ n_objects = masks.shape[-3]
+
+ if masks.dim() == 5:
+ # Handling video data.
+ # bs x frames x objects x H x W
+ mask_vis = masks.transpose(1, 2).contiguous()
+ flattened_masks = mask_vis.flatten(0, 1).unsqueeze(2)
+
+ # Draw masks inverted as they are easier to print.
+ mask_vis = torch.stack(
+ [
+ make_grid(1.0 - masks, nrow=n_objects)
+ for masks in torch.unbind(flattened_masks, 1)
+ ],
+ dim=0,
+ )
+ mask_vis = mask_vis.unsqueeze(0)
+ return visualization_types.Video(mask_vis, fps=self.fps)
+ elif masks.dim() == 4:
+ # Handling image data.
+ # bs x objects x H x W
+ # Monochrome image with single channel.
+ masks = masks.view(-1, 1, *image_shape)
+ # Draw masks inverted as they are easier to print.
+ return visualization_types.Image(make_grid(1.0 - masks, nrow=n_objects))
+
+
+class VisualObject(RoutableMixin):
+ def __init__(
+ self,
+ n_instances: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ object_path: Optional[str] = None,
+ mask_path: Optional[str] = None,
+ fps: int = 10,
+ ):
+ super().__init__({"object_reconstructions": object_path, "object_masks": mask_path})
+ self.n_instances = n_instances
+ self.denormalization = denormalization if denormalization else _nop
+ self.fps = fps
+
+ @RoutableMixin.route
+ def __call__(self, object_reconstructions, object_masks):
+ objects = object_reconstructions[: self.n_instances].cpu()
+ masks = object_masks[: self.n_instances].cpu().contiguous()
+
+ image_shape = objects.shape[-3:]
+ n_objects = objects.shape[-4]
+
+ if objects.dim() == 6:
+ # Handling video data.
+ # bs x frames x objects x C x H x W
+
+ # We need to denormalize prior to constructing the grid, yet the denormalization
+ # method assumes video input. We thus convert a frame into a single frame video and
+ # remove the additional dimension prior to make_grid.
+ # Switch object and frame dimension.
+ object_vis = objects.transpose(1, 2).contiguous()
+ mask_vis = masks.transpose(1, 2).contiguous()
+ flattened_masks = mask_vis.flatten(0, 1).unsqueeze(2)
+ object_vis = self.denormalization(object_vis.flatten(0, 1))
+            # Keep object pixels and apply a white background to non-object parts.
+ object_vis = object_vis * flattened_masks + (1.0 - flattened_masks)
+ object_vis = torch.stack(
+ [
+ make_grid(
+ object_vis_frame,
+ nrow=n_objects,
+ )
+ for object_vis_frame in torch.unbind(object_vis, 1)
+ ],
+ dim=0,
+ )
+ # Add batch dimension as this is required for video input.
+ object_vis = object_vis.unsqueeze(0)
+
+ # Draw masks inverted as they are easier to print.
+ mask_vis = torch.stack(
+ [
+ make_grid(1.0 - masks, nrow=n_objects)
+ for masks in torch.unbind(flattened_masks, 1)
+ ],
+ dim=0,
+ )
+ mask_vis = mask_vis.unsqueeze(0)
+ return {
+ "reconstruction": visualization_types.Video(object_vis, fps=self.fps),
+ "mask": visualization_types.Video(mask_vis, fps=self.fps),
+ }
+ elif objects.dim() == 5:
+ # Handling image data.
+ # bs x objects x C x H x W
+ object_reconstructions = self.denormalization(objects.view(-1, *image_shape))
+ # Monochrome image with single channel.
+ masks = masks.view(-1, 1, *image_shape[1:])
+ # Save object reconstructions as RGBA image. make_grid does not support RGBA input, thus
+ # we combine the channels later. For the masks we need to pad with 1 as we want the
+ # borders between images to remain visible (i.e. alpha value of 1.)
+ masks_grid = make_grid(masks, nrow=n_objects, pad_value=1.0)
+ object_grid = make_grid(object_reconstructions, nrow=n_objects)
+ # masks_grid expands the image to three channels, which we don't need. Only keep one, and
+            # use it as the alpha channel. After make_grid the tensor has the shape C x H x W.
+ object_grid = torch.cat((object_grid, masks_grid[:1]), dim=0)
+
+ return {
+ "reconstruction": visualization_types.Image(object_grid),
+ # Draw masks inverted as they are easier to print.
+ "mask": visualization_types.Image(make_grid(1.0 - masks, nrow=n_objects)),
+ }
+
+
+class ConsistencyMask(RoutableMixin):
+ def __init__(
+ self,
+ matcher: consistency.HungarianMatcher,
+ mask_path: Optional[str] = None,
+ mask_target_path: Optional[str] = None,
+ params_path: Optional[str] = None,
+ ):
+ super().__init__(
+ {"mask": mask_path, "mask_target": mask_target_path, "cropping_params": params_path}
+ )
+ self.matcher = matcher
+
+ @RoutableMixin.route
+ def __call__(self, mask: torch.Tensor, mask_target: torch.Tensor, cropping_params: torch.Tensor):
+ _, _, size, _ = mask.shape
+ mask_one_hot = self._to_binary_mask(mask)
+ mask_target = self.crop_views(mask_target, cropping_params, size)
+ mask_target_one_hot = self._to_binary_mask(mask_target)
+ _ = self.matcher(mask_one_hot, mask_target_one_hot)
+ return {
+ "costs": visualization_types.Image(
+ make_grid(-self.matcher.costs, nrow=8, pad_value=0.9)
+ ),
+ }
+
+ @staticmethod
+ def _to_binary_mask(masks: torch.Tensor):
+ _, n_objects, _, _ = masks.shape
+        m_labels = masks.argmax(dim=1)
+        mask_one_hot = torch.nn.functional.one_hot(m_labels, n_objects)
+ return mask_one_hot.permute(0, 3, 1, 2)
+
+ def crop_views(self, view: torch.Tensor, param: torch.Tensor, size: int):
+ return torch.cat([self.crop_maping(v, p, size) for v, p in zip(view, param)])
+
+ @staticmethod
+ def crop_maping(view: torch.Tensor, p: torch.Tensor, size: int):
+ p = tuple(p.cpu().numpy().astype(int))
+ return transforms.functional.resized_crop(view, *p, size=(size, size))[None]
+
+
+class Segmentation(RoutableMixin):
+ def __init__(
+ self,
+ n_instances: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ image_path: Optional[str] = None,
+ mask_path: Optional[str] = None,
+ ):
+ super().__init__({"image": image_path, "mask": mask_path})
+ self.n_instances = n_instances
+ self.denormalization = denormalization if denormalization else _nop
+ self._cmap_cache: Dict[int, List[Tuple[int, int, int]]] = {}
+
+ def _get_cmap(self, num_classes: int) -> List[Tuple[int, int, int]]:
+ if num_classes in self._cmap_cache:
+ return self._cmap_cache[num_classes]
+
+ from matplotlib import cm
+
+ if num_classes <= 20:
+ mpl_cmap = cm.get_cmap("tab20", num_classes)(range(num_classes))
+ else:
+ mpl_cmap = cm.get_cmap("turbo", num_classes)(range(num_classes))
+
+ cmap = [tuple((255 * cl[:3]).astype(int)) for cl in mpl_cmap]
+ self._cmap_cache[num_classes] = cmap
+ return cmap
+
+ @RoutableMixin.route
+ def __call__(
+ self, image: torch.Tensor, mask: torch.Tensor
+ ) -> Optional[visualization_types.Visualization]:
+ image = image[: self.n_instances].cpu()
+ mask = mask[: self.n_instances].cpu().contiguous()
+ if image.dim() == 4: # Only support image data at the moment.
+ input_image = self.denormalization(image)
+ n_objects = mask.shape[1]
+
+ masks_argmax = mask.argmax(dim=1)[:, None]
+ classes = torch.arange(n_objects)[None, :, None, None].to(masks_argmax)
+ masks_one_hot = masks_argmax == classes
+
+ cmap = self._get_cmap(n_objects)
+ masks_on_image = torch.stack(
+ [
+ draw_segmentation_masks(
+ (255 * img).to(torch.uint8), mask, alpha=0.75, colors=cmap
+ )
+ for img, mask in zip(input_image.to("cpu"), masks_one_hot.to("cpu"))
+ ]
+ )
+
+ return visualization_types.Image(make_grid(masks_on_image, nrow=8))
+ return None
+
+
+class Flow(RoutableMixin):
+ def __init__(
+ self,
+ n_instances: int = 8,
+ n_row: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ as_grid: bool = True,
+ flow_path: Optional[str] = None,
+ ):
+ super().__init__({"flow": flow_path})
+ self.n_instances = n_instances
+ self.n_row = n_row
+ self.denormalization = denormalization if denormalization else _nop
+ self.as_grid = as_grid
+
+ @RoutableMixin.route
+ def __call__(self, flow: torch.Tensor):
+ flow = self.denormalization(flow[: self.n_instances].cpu())
+ flow = _flow_tensor_to_rgb_tensor(flow)
+ if self.as_grid:
+ return visualization_types.Image(make_grid(flow, nrow=self.n_row))
+ else:
+ return visualization_types.Images(flow)
+
+
+color_list = [
+ "red",
+ "blue",
+ "green",
+ "yellow",
+ "pink",
+ "black",
+ "#614051",
+ "#cd7f32",
+ "#008b8b",
+ "#556b2f",
+ "#ffbf00",
+]
+
+
+class TrackedObject(RoutableMixin):
+ def __init__(
+ self,
+ n_clips: int = 3,
+ n_row: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ video_path: Optional[str] = None,
+ bbox_path: Optional[str] = None,
+ idx_path: Optional[str] = None,
+ ):
+ super().__init__({"video": video_path, "bbox": bbox_path, "idx": idx_path})
+ self.n_clips = n_clips
+ self.n_row = n_row
+ self.denormalization = denormalization if denormalization else _nop
+
+ @RoutableMixin.route
+ def __call__(
+ self, video: torch.Tensor, bbox: torch.Tensor, idx: torch.Tensor
+ ) -> Optional[visualization_types.Visualization]:
+ video = video[: self.n_clips].cpu()
+ num_frames = video.shape[1]
+
+ bbox = bbox[: self.n_clips].to(torch.uint8).cpu()
+ idx = idx[: self.n_clips].cpu()
+
+ rendered_video = torch.zeros_like(video)
+ num_color = len(color_list)
+
+ for cidx in range(self.n_clips):
+ for fidx in range(num_frames):
+ if cidx >= idx.shape[0] or fidx >= idx.shape[1]:
+ break
+ cur_obj_idx = idx[cidx, fidx]
+ valid_index = cur_obj_idx > -1
+ cur_obj_idx = cur_obj_idx[valid_index].to(torch.int)
+ cur_color_list = [
+ color_list[obj_idx % num_color] for obj_idx in cur_obj_idx.numpy().tolist()
+ ]
+ frame = (video[cidx, fidx] * 256).to(torch.uint8)
+ frame = draw_bounding_boxes(
+ frame, bbox[cidx, fidx][valid_index], colors=cur_color_list
+ )
+ rendered_video[cidx, fidx] = frame
+
+ rendered_video = (
+ torch.stack(
+ [
+ make_grid(self.denormalization(frame.unsqueeze(1)).squeeze(1), nrow=self.n_row)
+ for frame in torch.unbind(rendered_video, 1)
+ ],
+ dim=0,
+ )
+ .unsqueeze(0)
+ .to(torch.float32)
+ )
+
+ return visualization_types.Video(rendered_video)
+
+
+class TrackedObject_from_Mask(RoutableMixin):
+ def __init__(
+ self,
+ n_clips: int = 3,
+ n_row: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ fps: int = 10,
+ video_path: Optional[str] = None,
+ mask_path: Optional[str] = None,
+ ):
+ super().__init__({"video": video_path, "object_masks": mask_path})
+ self.n_clips = n_clips
+ self.n_row = n_row
+ self.denormalization = denormalization if denormalization else _nop
+ self.fps = fps
+
+ @RoutableMixin.route
+ def __call__(
+ self,
+ video: torch.Tensor,
+ object_masks: torch.Tensor,
+ ) -> Optional[visualization_types.Visualization]:
+ video = video[: self.n_clips].cpu()
+ num_frames = video.shape[1]
+
+ masks = object_masks[: self.n_clips]
+ B, F, C, h, w = masks.shape
+ masks = masks > 0.5
+
+ rendered_video = torch.zeros_like(video)
+
+ for cidx in range(self.n_clips):
+ for fidx in range(num_frames):
+ # Channels with a non-empty mask; `bbox` rows are indexed by position in `idx`.
+ idx = [i for i in range(C) if torch.sum(masks[cidx, fidx][i]) != 0]
+ bbox = masks_to_boxes(masks[cidx, fidx][idx]).cpu().contiguous()
+
+ # Keep objects whose (x1, y1, x2, y2) box covers less than 20% of the frame.
+ thres = h * w * 0.2
+ keep = []
+ for pos in range(len(idx)):
+ box_w = bbox[pos][2] - bbox[pos][0]
+ box_h = bbox[pos][3] - bbox[pos][1]
+ if box_w * box_h < thres:
+ keep.append(pos)
+ cur_color_list = [color_list[idx[pos] % len(color_list)] for pos in keep]
+ frame = (video[cidx, fidx] * 256).to(torch.uint8)
+ frame = draw_bounding_boxes(frame, bbox[keep], colors=cur_color_list)
+ rendered_video[cidx, fidx] = frame
+
+ rendered_video = (
+ torch.stack(
+ [
+ make_grid(self.denormalization(frame.unsqueeze(1)).squeeze(1), nrow=self.n_row)
+ for frame in torch.unbind(rendered_video, 1)
+ ],
+ dim=0,
+ )
+ .unsqueeze(0)
+ .to(torch.float32)
+ / 256
+ )
+
+ return visualization_types.Video(rendered_video, fps=self.fps)
+
+
+def masks_to_bboxes_xyxy(masks: torch.Tensor, empty_value: float = -1.0) -> torch.Tensor:
+ """Compute bounding boxes around the provided masks.
+
+ Adapted from DETR: https://github.com/facebookresearch/detr/blob/main/util/box_ops.py
+
+ Args:
+ masks: Tensor of shape (N, H, W), where N is the number of masks, H and W are the spatial
+ dimensions.
+ empty_value: Value bounding boxes should contain for empty masks.
+
+ Returns:
+ Tensor of shape (N, 4), containing bounding boxes in (x1, y1, x2, y2) format, where (x1, y1)
+ is the coordinate of top-left corner and (x2, y2) is the coordinate of the bottom-right
+ corner (inclusive) in pixel coordinates. If mask is empty, all coordinates contain
+ `empty_value` instead.
+ """
+ masks = masks.bool()
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ large_value = 1e8
+ inv_mask = ~masks
+
+ h, w = masks.shape[-2:]
+
+ y = torch.arange(0, h, dtype=torch.float, device=masks.device)
+ x = torch.arange(0, w, dtype=torch.float, device=masks.device)
+ y, x = torch.meshgrid(y, x, indexing="ij")
+
+ x_mask = masks * x.unsqueeze(0)
+ x_max = x_mask.flatten(1).max(-1)[0]
+ x_min = x_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]
+
+ y_mask = masks * y.unsqueeze(0)
+ y_max = y_mask.flatten(1).max(-1)[0]
+ y_min = y_mask.masked_fill(inv_mask, large_value).flatten(1).min(-1)[0]
+
+ bboxes = torch.stack((x_min, y_min, x_max, y_max), dim=1) # x1y1x2y2
+ bboxes[x_min == large_value] = empty_value
+ return bboxes
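+
+# Minimal usage sketch (illustrative only; the toy tensors below are hypothetical):
+#
+#   masks = torch.zeros(2, 4, 4, dtype=torch.bool)
+#   masks[0, 1:3, 1:3] = True   # a 2x2 square object in the first mask
+#   boxes = masks_to_bboxes_xyxy(masks)
+#   # boxes[0] -> tensor([1., 1., 2., 2.])      inclusive (x1, y1, x2, y2)
+#   # boxes[1] -> tensor([-1., -1., -1., -1.])  empty mask -> empty_value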
+
+
+class ObjectMOT(RoutableMixin):
+ def __init__(
+ self,
+ n_clips: int = 3,
+ n_row: int = 8,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ video_path: Optional[str] = None,
+ mask_path: Optional[str] = None,
+ ):
+ super().__init__({"video": video_path, "object_masks": mask_path})
+ self.n_clips = n_clips
+ self.n_row = n_row
+ self.denormalization = denormalization if denormalization else _nop
+
+ @RoutableMixin.route
+ def __call__(
+ self,
+ video: torch.Tensor,
+ object_masks: torch.Tensor,
+ ) -> Optional[visualization_types.Visualization]:
+ video = video[: self.n_clips].cpu()
+ num_frames = video.shape[1]
+
+ masks = object_masks[: self.n_clips].cpu().contiguous()
+ B, F, C, h, w = masks.shape # [5, 6, 11, 64, 64]
+ masks = masks.flatten(0, 1)
+ masks = masks > 0.5
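+ # `masks` is already (B*F, C, h, w); the second flatten yields (B*F*C, h, w) so
+ # one box is computed per object mask, and unflatten restores (B, F, C, 4).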
+ bbox = masks_to_bboxes_xyxy(masks.flatten(0, 1)).unflatten(0, (B, F, C))
+
+ rendered_video = torch.zeros_like(video)
+
+ color_list = [
+ "red",
+ "blue",
+ "green",
+ "yellow",
+ "pink",
+ "black",
+ "#614051",
+ "#cd7f32",
+ "#008b8b",
+ "#556b2f",
+ "#ffbf00",
+ "white",
+ "orange",
+ "gray",
+ "#ffbf00",
+ ]
+
+ for cidx in range(self.n_clips):
+ for fidx in range(num_frames):
+ cur_obj_box = bbox[cidx, fidx][:, 0] != -1.0
+ cur_obj_idx = cur_obj_box.nonzero()[:, 0].detach().cpu().numpy()
+ idx = cur_obj_idx.tolist()
+ cur_obj_idx = np.array(idx)
+ cur_color_list = [color_list[obj_idx] for obj_idx in idx]
+ frame = (video[cidx, fidx] * 256).to(torch.uint8)
+ frame = draw_bounding_boxes(
+ frame, bbox[cidx, fidx][cur_obj_idx], colors=cur_color_list
+ )
+ rendered_video[cidx, fidx] = frame
+
+ rendered_video = (
+ torch.stack(
+ [
+ make_grid(self.denormalization(frame.unsqueeze(1)).squeeze(1), nrow=self.n_row)
+ for frame in torch.unbind(rendered_video, 1)
+ ],
+ dim=0,
+ )
+ .unsqueeze(0)
+ .to(torch.float32)
+ / 256
+ )
+
+ return visualization_types.Video(rendered_video)
+
+
+class TextToImageMatching(RoutableMixin):
+ def __init__(
+ self,
+ denormalization: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+ row_is_image: bool = True,
+ n_instances: Optional[int] = None,
+ image_path: Optional[str] = None,
+ text_path: Optional[str] = None,
+ similarities_path: Optional[str] = None,
+ ):
+ super().__init__({"image": image_path, "text": text_path, "similarities": similarities_path})
+ self.denormalization = denormalization if denormalization else _nop
+ self.row_is_image = row_is_image
+ self.n_instances = n_instances
+
+ @RoutableMixin.route
+ def __call__(
+ self, image: torch.Tensor, text: List[str], similarities: torch.Tensor
+ ) -> Optional[visualization_types.Visualization]:
+ n_images = len(image)
+ n_texts = len(text)
+
+ image = image.detach()
+ if self.row_is_image:
+ # The code assumes that each row of the similarity matrix corresponds to a single text.
+ similarities = similarities.T.detach()
+ else:
+ similarities = similarities.detach()
+
+ assert n_texts == similarities.shape[0]
+ assert n_images == similarities.shape[1]
+
+ if self.n_instances:
+ n_images = min(self.n_instances, n_images)
+ n_texts = min(self.n_instances, n_texts)
+
+ image = (
+ torch.clamp((255 * self.denormalization(image[:n_images])), 0, 255)
+ .to(torch.uint8)
+ .permute(0, 2, 3, 1)
+ .cpu()
+ )
+ text = text[:n_texts]
+ similarities = similarities[:n_texts, :n_images].cpu()
+
+ fig, ax = plt.subplots(1, 1, figsize=(20, 14))
+ ax.imshow(similarities, vmin=0.1, vmax=0.3)
+
+ ax.set_yticks(range(n_texts), text, fontsize=18, wrap=True)
+ ax.set_xticks([])
+ for i, cur_image in enumerate(image):
+ ax.imshow(cur_image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
+ for x in range(similarities.shape[1]):
+ for y in range(similarities.shape[0]):
+ ax.text(x, y, f"{similarities[y, x]:.2f}", ha="center", va="center", size=12)
+
+ for side in ["left", "top", "right", "bottom"]:
+ ax.spines[side].set_visible(False)
+
+ ax.set_xlim([-0.5, n_images - 0.5])
+ ax.set_ylim([n_texts + 0.5, -2])
+ return visualization_types.Figure(fig)
diff --git a/poetry.lock b/poetry.lock
new file mode 100644
index 0000000..fd1a67c
--- /dev/null
+++ b/poetry.lock
@@ -0,0 +1,4714 @@
+[[package]]
+name = "absl-py"
+version = "1.1.0"
+description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "aiohttp"
+version = "3.8.1"
+description = "Async http client/server framework (asyncio)"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+aiosignal = ">=1.1.2"
+async-timeout = ">=4.0.0a3,<5.0"
+asynctest = {version = "0.13.0", markers = "python_version < \"3.8\""}
+attrs = ">=17.3.0"
+charset-normalizer = ">=2.0,<3.0"
+frozenlist = ">=1.1.1"
+multidict = ">=4.5,<7.0"
+typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
+yarl = ">=1.0,<2.0"
+
+[package.extras]
+speedups = ["Brotli", "aiodns", "cchardet"]
+
+[[package]]
+name = "aiosignal"
+version = "1.2.0"
+description = "aiosignal: a list of registered asynchronous callbacks"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+frozenlist = ">=1.1.0"
+
+[[package]]
+name = "alembic"
+version = "1.8.1"
+description = "A database migration tool for SQLAlchemy."
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.dependencies]
+importlib-metadata = {version = "*", markers = "python_version < \"3.9\""}
+importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
+Mako = "*"
+SQLAlchemy = ">=1.3.0"
+
+[package.extras]
+tz = ["python-dateutil"]
+
+[[package]]
+name = "antlr4-python3-runtime"
+version = "4.9.3"
+description = "ANTLR 4.9.3 runtime for Python 3.7"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "anyio"
+version = "3.6.2"
+description = "High level compatibility layer for multiple asynchronous event loop implementations"
+category = "main"
+optional = false
+python-versions = ">=3.6.2"
+
+[package.dependencies]
+idna = ">=2.8"
+sniffio = ">=1.1"
+typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
+
+[package.extras]
+doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
+test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"]
+trio = ["trio (>=0.16,<0.22)"]
+
+[[package]]
+name = "appnope"
+version = "0.1.3"
+description = "Disable App Nap on macOS >= 10.9"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "argon2-cffi"
+version = "21.3.0"
+description = "The secure Argon2 password hashing algorithm."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+argon2-cffi-bindings = "*"
+typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
+
+[package.extras]
+dev = ["cogapp", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "pre-commit", "pytest", "sphinx", "sphinx-notfound-page", "tomli"]
+docs = ["furo", "sphinx", "sphinx-notfound-page"]
+tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"]
+
+[[package]]
+name = "argon2-cffi-bindings"
+version = "21.2.0"
+description = "Low-level CFFI bindings for Argon2"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+cffi = ">=1.0.1"
+
+[package.extras]
+dev = ["cogapp", "pre-commit", "pytest", "wheel"]
+tests = ["pytest"]
+
+[[package]]
+name = "async-timeout"
+version = "4.0.2"
+description = "Timeout context manager for asyncio programs"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""}
+
+[[package]]
+name = "asynctest"
+version = "0.13.0"
+description = "Enhance the standard unittest package with features for testing asyncio libraries"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[[package]]
+name = "atomicwrites"
+version = "1.4.0"
+description = "Atomic file writes."
+category = "dev"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+
+[[package]]
+name = "attrs"
+version = "21.4.0"
+description = "Classes Without Boilerplate"
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+
+[package.extras]
+dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six", "sphinx", "sphinx-notfound-page", "zope.interface"]
+docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
+tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six", "zope.interface"]
+tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "six"]
+
+[[package]]
+name = "awscli"
+version = "1.25.22"
+description = "Universal Command Line Environment for AWS."
+category = "main"
+optional = false
+python-versions = ">= 3.7"
+
+[package.dependencies]
+botocore = "1.27.22"
+colorama = ">=0.2.5,<0.4.5"
+docutils = ">=0.10,<0.17"
+PyYAML = ">=3.10,<5.5"
+rsa = ">=3.1.2,<4.8"
+s3transfer = ">=0.6.0,<0.7.0"
+
+[[package]]
+name = "awscrt"
+version = "0.13.8"
+description = "A common runtime for AWS Python projects"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "backcall"
+version = "0.2.0"
+description = "Specifications for callback functions passed in to an API"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "beautifulsoup4"
+version = "4.12.2"
+description = "Screen-scraping library"
+category = "main"
+optional = false
+python-versions = ">=3.6.0"
+
+[package.dependencies]
+soupsieve = ">1.2"
+
+[package.extras]
+html5lib = ["html5lib"]
+lxml = ["lxml"]
+
+[[package]]
+name = "black"
+version = "22.6.0"
+description = "The uncompromising code formatter."
+category = "dev"
+optional = false
+python-versions = ">=3.6.2"
+
+[package.dependencies]
+click = ">=8.0.0"
+mypy-extensions = ">=0.4.3"
+pathspec = ">=0.9.0"
+platformdirs = ">=2"
+tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""}
+typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""}
+typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+colorama = ["colorama (>=0.4.3)"]
+d = ["aiohttp (>=3.7.4)"]
+jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
+uvloop = ["uvloop (>=0.15.2)"]
+
+[[package]]
+name = "bleach"
+version = "6.0.0"
+description = "An easy safelist-based HTML-sanitizing tool."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+six = ">=1.9.0"
+webencodings = "*"
+
+[package.extras]
+css = ["tinycss2 (>=1.1.0,<1.2)"]
+
+[[package]]
+name = "botocore"
+version = "1.27.22"
+description = "Low-level, data-driven core of boto 3."
+category = "main"
+optional = false
+python-versions = ">= 3.7"
+
+[package.dependencies]
+awscrt = {version = "0.13.8", optional = true, markers = "extra == \"crt\""}
+jmespath = ">=0.7.1,<2.0.0"
+python-dateutil = ">=2.1,<3.0.0"
+urllib3 = ">=1.25.4,<1.27"
+
+[package.extras]
+crt = ["awscrt (==0.13.8)"]
+
+[[package]]
+name = "braceexpand"
+version = "0.1.7"
+description = "Bash-style brace expansion for Python"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "cachetools"
+version = "5.2.0"
+description = "Extensible memoizing collections and decorators"
+category = "main"
+optional = false
+python-versions = "~=3.7"
+
+[[package]]
+name = "certifi"
+version = "2022.6.15"
+description = "Python package for providing Mozilla's CA Bundle."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "cffi"
+version = "1.15.1"
+description = "Foreign Function Interface for Python calling C code."
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+pycparser = "*"
+
+[[package]]
+name = "cfgv"
+version = "3.3.1"
+description = "Validate configuration and produce human readable error messages."
+category = "dev"
+optional = false
+python-versions = ">=3.6.1"
+
+[[package]]
+name = "charset-normalizer"
+version = "2.1.0"
+description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+category = "main"
+optional = false
+python-versions = ">=3.6.0"
+
+[package.extras]
+unicode_backport = ["unicodedata2"]
+
+[[package]]
+name = "click"
+version = "8.1.3"
+description = "Composable command line interface toolkit"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+
+[[package]]
+name = "cloudpickle"
+version = "2.1.0"
+description = "Extended pickling support for Python objects"
+category = "main"
+optional = true
+python-versions = ">=3.6"
+
+[[package]]
+name = "colorama"
+version = "0.4.4"
+description = "Cross-platform colored terminal text."
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+
+[[package]]
+name = "cycler"
+version = "0.11.0"
+description = "Composable style cycles"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "databricks-cli"
+version = "0.17.3"
+description = "A command line interface for Databricks"
+category = "main"
+optional = true
+python-versions = "*"
+
+[package.dependencies]
+click = ">=7.0"
+oauthlib = ">=3.1.0"
+pyjwt = ">=1.7.0"
+requests = ">=2.17.3"
+six = ">=1.10.0"
+tabulate = ">=0.7.7"
+
+[[package]]
+name = "debugpy"
+version = "1.6.7"
+description = "An implementation of the Debug Adapter Protocol for Python"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "decorator"
+version = "4.4.2"
+description = "Decorators for Humans"
+category = "main"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*"
+
+[[package]]
+name = "decord"
+version = "0.6.0"
+description = "Decord Video Loader"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+numpy = ">=1.14.0"
+
+[[package]]
+name = "defusedxml"
+version = "0.7.1"
+description = "XML bomb protection for Python stdlib modules"
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+
+[[package]]
+name = "distlib"
+version = "0.3.4"
+description = "Distribution utilities"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "docker"
+version = "6.0.0"
+description = "A Python library for the Docker Engine API."
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.dependencies]
+packaging = ">=14.0"
+pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""}
+requests = ">=2.26.0"
+urllib3 = ">=1.26.0"
+websocket-client = ">=0.32.0"
+
+[package.extras]
+ssh = ["paramiko (>=2.4.3)"]
+
+[[package]]
+name = "docutils"
+version = "0.16"
+description = "Docutils -- Python Documentation Utilities"
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+
+[[package]]
+name = "einops"
+version = "0.6.0"
+description = "A new flavour of deep learning operations"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "entrypoints"
+version = "0.4"
+description = "Discover and load entry points from installed packages."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "fastjsonschema"
+version = "2.16.3"
+description = "Fastest Python implementation of JSON schema"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.extras]
+devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"]
+
+[[package]]
+name = "filelock"
+version = "3.7.1"
+description = "A platform independent file lock."
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+docs = ["furo (>=2021.8.17b43)", "sphinx (>=4.1)", "sphinx-autodoc-typehints (>=1.12)"]
+testing = ["covdefaults (>=1.2.0)", "coverage (>=4)", "pytest (>=4)", "pytest-cov", "pytest-timeout (>=1.4.2)"]
+
+[[package]]
+name = "flake8"
+version = "4.0.1"
+description = "the modular source code checker: pep8 pyflakes and co"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+importlib-metadata = {version = "<4.3", markers = "python_version < \"3.8\""}
+mccabe = ">=0.6.0,<0.7.0"
+pycodestyle = ">=2.8.0,<2.9.0"
+pyflakes = ">=2.4.0,<2.5.0"
+
+[[package]]
+name = "flake8-bugbear"
+version = "22.7.1"
+description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle."
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+attrs = ">=19.2.0"
+flake8 = ">=3.0.0"
+
+[package.extras]
+dev = ["coverage", "hypothesis", "hypothesmith (>=0.2)", "pre-commit"]
+
+[[package]]
+name = "flake8-docstrings"
+version = "1.6.0"
+description = "Extension for flake8 which uses pydocstyle to check docstrings"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+flake8 = ">=3"
+pydocstyle = ">=2.1"
+
+[[package]]
+name = "flake8-isort"
+version = "4.1.1"
+description = "flake8 plugin that integrates isort ."
+category = "dev"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+flake8 = ">=3.2.1,<5"
+isort = ">=4.3.5,<6"
+testfixtures = ">=6.8.0,<7"
+
+[package.extras]
+test = ["pytest-cov"]
+
+[[package]]
+name = "flake8-tidy-imports"
+version = "4.8.0"
+description = "A flake8 plugin that helps you write tidier imports."
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+flake8 = ">=3.8.0"
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+
+[[package]]
+name = "Flask"
+version = "2.1.3"
+description = "A simple framework for building complex web applications."
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.dependencies]
+click = ">=8.0"
+importlib-metadata = {version = ">=3.6.0", markers = "python_version < \"3.10\""}
+itsdangerous = ">=2.0"
+Jinja2 = ">=3.0"
+Werkzeug = ">=2.0"
+
+[package.extras]
+async = ["asgiref (>=3.2)"]
+dotenv = ["python-dotenv"]
+
+[[package]]
+name = "fonttools"
+version = "4.33.3"
+description = "Tools to manipulate font files"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=14.0.0)", "xattr", "zopfli (>=0.1.4)"]
+graphite = ["lz4 (>=1.7.4.2)"]
+interpolatable = ["munkres", "scipy"]
+lxml = ["lxml (>=4.0,<5)"]
+pathops = ["skia-pathops (>=0.5.0)"]
+plot = ["matplotlib"]
+repacker = ["uharfbuzz (>=0.23.0)"]
+symfont = ["sympy"]
+type1 = ["xattr"]
+ufo = ["fs (>=2.2.0,<3)"]
+unicode = ["unicodedata2 (>=14.0.0)"]
+woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"]
+
+[[package]]
+name = "frozenlist"
+version = "1.3.0"
+description = "A list-like structure which implements collections.abc.MutableSequence"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "fsspec"
+version = "2022.7.1"
+description = "File-system specification"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+aiohttp = {version = "*", optional = true, markers = "extra == \"http\""}
+requests = {version = "*", optional = true, markers = "extra == \"http\""}
+
+[package.extras]
+abfs = ["adlfs"]
+adl = ["adlfs"]
+arrow = ["pyarrow (>=1)"]
+dask = ["dask", "distributed"]
+dropbox = ["dropbox", "dropboxdrivefs", "requests"]
+entrypoints = ["importlib-metadata"]
+fuse = ["fusepy"]
+gcs = ["gcsfs"]
+git = ["pygit2"]
+github = ["requests"]
+gs = ["gcsfs"]
+gui = ["panel"]
+hdfs = ["pyarrow (>=1)"]
+http = ["aiohttp", "requests"]
+libarchive = ["libarchive-c"]
+oci = ["ocifs"]
+s3 = ["s3fs"]
+sftp = ["paramiko"]
+smb = ["smbprotocol"]
+ssh = ["paramiko"]
+tqdm = ["tqdm"]
+
+[[package]]
+name = "ftfy"
+version = "6.1.1"
+description = "Fixes mojibake and other problems with Unicode, after the fact"
+category = "main"
+optional = true
+python-versions = ">=3.7,<4"
+
+[package.dependencies]
+wcwidth = ">=0.2.5"
+
+[[package]]
+name = "gitdb"
+version = "4.0.9"
+description = "Git Object Database"
+category = "main"
+optional = true
+python-versions = ">=3.6"
+
+[package.dependencies]
+smmap = ">=3.0.1,<6"
+
+[[package]]
+name = "GitPython"
+version = "3.1.29"
+description = "GitPython is a python library used to interact with Git repositories"
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.dependencies]
+gitdb = ">=4.0.1,<5"
+typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\""}
+
+[[package]]
+name = "google-auth"
+version = "2.9.0"
+description = "Google Authentication Library"
+category = "main"
+optional = false
+python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
+
+[package.dependencies]
+cachetools = ">=2.0.0,<6.0"
+pyasn1-modules = ">=0.2.1"
+rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""}
+six = ">=1.9.0"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "requests (>=2.20.0,<3.0.0dev)"]
+enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
+pyopenssl = ["pyopenssl (>=20.0.0)"]
+reauth = ["pyu2f (>=0.1.5)"]
+
+[[package]]
+name = "google-auth-oauthlib"
+version = "0.4.6"
+description = "Google Authentication Library"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+google-auth = ">=1.0.0"
+requests-oauthlib = ">=0.7.0"
+
+[package.extras]
+tool = ["click (>=6.0.0)"]
+
+[[package]]
+name = "greenlet"
+version = "1.1.3.post0"
+description = "Lightweight in-process concurrent programming"
+category = "main"
+optional = true
+python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*"
+
+[package.extras]
+docs = ["Sphinx"]
+
+[[package]]
+name = "grpcio"
+version = "1.47.0"
+description = "HTTP/2-based RPC framework"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+six = ">=1.5.2"
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.47.0)"]
+
+[[package]]
+name = "gunicorn"
+version = "20.1.0"
+description = "WSGI HTTP Server for UNIX"
+category = "main"
+optional = true
+python-versions = ">=3.5"
+
+[package.dependencies]
+setuptools = ">=3.0"
+
+[package.extras]
+eventlet = ["eventlet (>=0.24.1)"]
+gevent = ["gevent (>=1.4.0)"]
+setproctitle = ["setproctitle"]
+tornado = ["tornado (>=0.2)"]
+
+[[package]]
+name = "hydra-core"
+version = "1.2.0"
+description = "A framework for elegantly configuring complex applications"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+antlr4-python3-runtime = ">=4.9.0,<4.10.0"
+importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
+omegaconf = ">=2.2,<3.0"
+packaging = "*"
+
+[[package]]
+name = "hydra-submitit-launcher"
+version = "1.2.0"
+description = "Submitit Launcher for Hydra apps"
+category = "main"
+optional = true
+python-versions = "*"
+
+[package.dependencies]
+hydra-core = ">=1.1.0.dev7"
+submitit = ">=1.3.3"
+
+[[package]]
+name = "hydra-zen"
+version = "0.7.1"
+description = "Configurable, reproducible, and scalable workflows in Python, via Hydra"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+hydra-core = ">=1.1.0"
+typing-extensions = ">=4.0.1"
+
+[package.extras]
+beartype = ["beartype (>=0.8.0)"]
+pydantic = ["pydantic (>=1.8.2)"]
+
+[[package]]
+name = "identify"
+version = "2.5.1"
+description = "File identification library for Python"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+license = ["ukkonen"]
+
+[[package]]
+name = "idna"
+version = "3.3"
+description = "Internationalized Domain Names in Applications (IDNA)"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[[package]]
+name = "imageio"
+version = "2.19.3"
+description = "Library for reading and writing a wide range of image, video, scientific, and volumetric data formats."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+numpy = "*"
+pillow = ">=8.3.2"
+
+[package.extras]
+all-plugins = ["astropy", "av", "imageio-ffmpeg", "opencv-python", "psutil", "tifffile"]
+all-plugins-pypy = ["av", "imageio-ffmpeg", "psutil", "tifffile"]
+build = ["wheel"]
+dev = ["black", "flake8", "fsspec[github]", "invoke", "pytest", "pytest-cov"]
+docs = ["numpydoc", "pydata-sphinx-theme", "sphinx"]
+ffmpeg = ["imageio-ffmpeg", "psutil"]
+fits = ["astropy"]
+full = ["astropy", "av", "black", "flake8", "fsspec[github]", "gdal", "imageio-ffmpeg", "invoke", "itk", "numpydoc", "opencv-python", "psutil", "pydata-sphinx-theme", "pytest", "pytest-cov", "sphinx", "tifffile", "wheel"]
+gdal = ["gdal"]
+itk = ["itk"]
+linting = ["black", "flake8"]
+opencv = ["opencv-python"]
+pyav = ["av"]
+test = ["fsspec[github]", "invoke", "pytest", "pytest-cov"]
+tifffile = ["tifffile"]
+
+[[package]]
+name = "imageio-ffmpeg"
+version = "0.4.7"
+description = "FFMPEG wrapper for Python"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[[package]]
+name = "importlib-metadata"
+version = "4.2.0"
+description = "Read metadata from Python packages"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""}
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["jaraco.packaging (>=8.2)", "rst.linker (>=1.9)", "sphinx"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pep517", "pyfakefs", "pytest (>=4.6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-flake8", "pytest-mypy"]
+
+[[package]]
+name = "importlib-resources"
+version = "5.8.0"
+description = "Read resources from Python packages"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"]
+testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+
+[[package]]
+name = "iniconfig"
+version = "1.1.1"
+description = "iniconfig: brain-dead simple config-ini parsing"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "ipykernel"
+version = "6.16.2"
+description = "IPython Kernel for Jupyter"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+appnope = {version = "*", markers = "platform_system == \"Darwin\""}
+debugpy = ">=1.0"
+ipython = ">=7.23.1"
+jupyter-client = ">=6.1.12"
+matplotlib-inline = ">=0.1"
+nest-asyncio = "*"
+packaging = "*"
+psutil = "*"
+pyzmq = ">=17"
+tornado = ">=6.1"
+traitlets = ">=5.1.0"
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinxcontrib-github-alt"]
+test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-cov", "pytest-timeout"]
+
+[[package]]
+name = "ipython"
+version = "7.34.0"
+description = "IPython: Productive Interactive Computing"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+appnope = {version = "*", markers = "sys_platform == \"darwin\""}
+backcall = "*"
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+decorator = "*"
+jedi = ">=0.16"
+matplotlib-inline = "*"
+pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""}
+pickleshare = "*"
+prompt-toolkit = ">=2.0.0,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.1.0"
+pygments = "*"
+setuptools = ">=18.5"
+traitlets = ">=4.2"
+
+[package.extras]
+all = ["Sphinx (>=1.3)", "ipykernel", "ipyparallel", "ipywidgets", "nbconvert", "nbformat", "nose (>=0.10.1)", "notebook", "numpy (>=1.17)", "pygments", "qtconsole", "requests", "testpath"]
+doc = ["Sphinx (>=1.3)"]
+kernel = ["ipykernel"]
+nbconvert = ["nbconvert"]
+nbformat = ["nbformat"]
+notebook = ["ipywidgets", "notebook"]
+parallel = ["ipyparallel"]
+qtconsole = ["qtconsole"]
+test = ["ipykernel", "nbformat", "nose (>=0.10.1)", "numpy (>=1.17)", "pygments", "requests", "testpath"]
+
+[[package]]
+name = "ipython_genutils"
+version = "0.2.0"
+description = "Vestigial utilities from IPython"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "ipywidgets"
+version = "8.0.6"
+description = "Jupyter interactive widgets"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+ipykernel = ">=4.5.1"
+ipython = ">=6.1.0"
+jupyterlab-widgets = ">=3.0.7,<3.1.0"
+traitlets = ">=4.3.1"
+widgetsnbextension = ">=4.0.7,<4.1.0"
+
+[package.extras]
+test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"]
+
+[[package]]
+name = "isort"
+version = "5.10.1"
+description = "A Python utility / library to sort Python imports."
+category = "dev"
+optional = false
+python-versions = ">=3.6.1,<4.0"
+
+[package.extras]
+colors = ["colorama (>=0.4.3,<0.5.0)"]
+pipfile_deprecated_finder = ["pipreqs", "requirementslib"]
+plugins = ["setuptools"]
+requirements_deprecated_finder = ["pip-api", "pipreqs"]
+
+[[package]]
+name = "itsdangerous"
+version = "2.1.2"
+description = "Safely pass data to untrusted environments and back."
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[[package]]
+name = "jedi"
+version = "0.18.2"
+description = "An autocompletion tool for Python that can be used for text editors."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+parso = ">=0.8.0,<0.9.0"
+
+[package.extras]
+docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"]
+qa = ["flake8 (==3.8.3)", "mypy (==0.782)"]
+testing = ["Django (<3.1)", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
+
+[[package]]
+name = "Jinja2"
+version = "3.1.2"
+description = "A very fast and expressive template engine."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+MarkupSafe = ">=2.0"
+
+[package.extras]
+i18n = ["Babel (>=2.7)"]
+
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+description = "JSON Matching Expressions"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "joblib"
+version = "1.1.0"
+description = "Lightweight pipelining with Python functions"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "jsonschema"
+version = "4.17.3"
+description = "An implementation of JSON Schema validation for Python"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+attrs = ">=17.4.0"
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
+pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
+pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"
+typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
+
+[package.extras]
+format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
+format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"]
+
+[[package]]
+name = "jupyter"
+version = "1.0.0"
+description = "Jupyter metapackage. Install all the Jupyter components in one go."
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+ipykernel = "*"
+ipywidgets = "*"
+jupyter-console = "*"
+nbconvert = "*"
+notebook = "*"
+qtconsole = "*"
+
+[[package]]
+name = "jupyter-client"
+version = "7.4.9"
+description = "Jupyter protocol implementation and client libraries"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+entrypoints = "*"
+jupyter-core = ">=4.9.2"
+nest-asyncio = ">=1.5.4"
+python-dateutil = ">=2.8.2"
+pyzmq = ">=23.0"
+tornado = ">=6.2"
+traitlets = "*"
+
+[package.extras]
+doc = ["ipykernel", "myst-parser", "sphinx (>=1.3.6)", "sphinx-rtd-theme", "sphinxcontrib-github-alt"]
+test = ["codecov", "coverage", "ipykernel (>=6.12)", "ipython", "mypy", "pre-commit", "pytest", "pytest-asyncio (>=0.18)", "pytest-cov", "pytest-timeout"]
+
+[[package]]
+name = "jupyter-console"
+version = "6.6.3"
+description = "Jupyter terminal console"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+ipykernel = ">=6.14"
+ipython = "*"
+jupyter-client = ">=7.0.0"
+jupyter-core = ">=4.12,<5.0.0 || >=5.1.0"
+prompt-toolkit = ">=3.0.30"
+pygments = "*"
+pyzmq = ">=17"
+traitlets = ">=5.4"
+
+[package.extras]
+test = ["flaky", "pexpect", "pytest"]
+
+[[package]]
+name = "jupyter-core"
+version = "4.12.0"
+description = "Jupyter core package. A base package on which Jupyter projects rely."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+pywin32 = {version = ">=1.0", markers = "sys_platform == \"win32\" and platform_python_implementation != \"PyPy\""}
+traitlets = "*"
+
+[package.extras]
+test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"]
+
+[[package]]
+name = "jupyter-server"
+version = "1.24.0"
+description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+anyio = ">=3.1.0,<4"
+argon2-cffi = "*"
+jinja2 = "*"
+jupyter-client = ">=6.1.12"
+jupyter-core = ">=4.12,<5.0.0 || >=5.1.0"
+nbconvert = ">=6.4.4"
+nbformat = ">=5.2.0"
+packaging = "*"
+prometheus-client = "*"
+pywinpty = {version = "*", markers = "os_name == \"nt\""}
+pyzmq = ">=17"
+Send2Trash = "*"
+terminado = ">=0.8.3"
+tornado = ">=6.1.0"
+traitlets = ">=5.1"
+websocket-client = "*"
+
+[package.extras]
+test = ["coverage", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-mock", "pytest-timeout", "pytest-tornasync", "requests"]
+
+[[package]]
+name = "jupyterlab-pygments"
+version = "0.2.2"
+description = "Pygments theme using JupyterLab CSS variables"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "jupyterlab-widgets"
+version = "3.0.7"
+description = "Jupyter interactive widgets for JupyterLab"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "kiwisolver"
+version = "1.4.3"
+description = "A fast implementation of the Cassowary constraint solver"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
+
+[[package]]
+name = "Mako"
+version = "1.2.3"
+description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.dependencies]
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+MarkupSafe = ">=0.9.2"
+
+[package.extras]
+babel = ["Babel"]
+lingua = ["lingua"]
+testing = ["pytest"]
+
+[[package]]
+name = "markdown"
+version = "3.3.5"
+description = "Python implementation of Markdown."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.extras]
+testing = ["coverage", "pyyaml"]
+
+[[package]]
+name = "MarkupSafe"
+version = "2.1.1"
+description = "Safely add untrusted strings to HTML/XML markup."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "matplotlib"
+version = "3.5.2"
+description = "Python plotting package"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+cycler = ">=0.10"
+fonttools = ">=4.22.0"
+kiwisolver = ">=1.0.1"
+numpy = ">=1.17"
+packaging = ">=20.0"
+pillow = ">=6.2.0"
+pyparsing = ">=2.2.1"
+python-dateutil = ">=2.7"
+setuptools_scm = ">=4"
+
+[[package]]
+name = "matplotlib-inline"
+version = "0.1.6"
+description = "Inline Matplotlib backend for Jupyter"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[package.dependencies]
+traitlets = "*"
+
+[[package]]
+name = "mccabe"
+version = "0.6.1"
+description = "McCabe checker, plugin for flake8"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "mistune"
+version = "2.0.5"
+description = "A sane Markdown parser with useful plugins and renderers"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "mlflow"
+version = "1.29.0"
+description = "MLflow: A Platform for ML Development and Productionization"
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.dependencies]
+alembic = "<2"
+click = ">=7.0,<9"
+cloudpickle = "<3"
+databricks-cli = ">=0.8.7,<1"
+docker = ">=4.0.0,<7"
+entrypoints = "<1"
+Flask = "<3"
+gitpython = ">=2.1.0,<4"
+gunicorn = {version = "<21", markers = "platform_system != \"Windows\""}
+importlib-metadata = ">=3.7.0,<4.7.0 || >4.7.0,<5"
+numpy = "<2"
+packaging = "<22"
+pandas = "<2"
+prometheus-flask-exporter = "<1"
+protobuf = ">=3.12.0,<5"
+pytz = "<2023"
+pyyaml = ">=5.1,<7"
+querystring-parser = "<2"
+requests = ">=2.17.3,<3"
+scipy = "<2"
+sqlalchemy = ">=1.4.0,<2"
+sqlparse = ">=0.4.0,<1"
+waitress = {version = "<3", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+aliyun-oss = ["aliyunstoreplugin"]
+extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=0.5.3)", "mlserver-mlflow (>=0.5.3)", "pyarrow", "pysftp", "scikit-learn", "virtualenv"]
+pipelines = ["Jinja2 (>=3.0)", "ipython (>=7.0)", "markdown (>=3.3)", "pandas-profiling (>=3.1)", "pyarrow (>=7.0)", "scikit-learn (>=1.0)", "shap (>=0.40)"]
+sqlserver = ["mlflow-dbstore"]
+
+[[package]]
+name = "motmetrics"
+version = "1.2.5"
+description = "Metrics for multiple object tracker benchmarking."
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+numpy = ">=1.12.1"
+pandas = ">=0.23.1"
+scipy = ">=0.19.0"
+xmltodict = ">=0.12.0"
+
+[[package]]
+name = "moviepy"
+version = "1.0.3"
+description = "Video editing with Python"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+decorator = ">=4.0.2,<5.0"
+imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""}
+imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""}
+numpy = [
+ {version = ">=1.17.3", markers = "python_version != \"2.7\""},
+ {version = "*", markers = "python_version >= \"2.7\""},
+]
+proglog = "<=1.0.0"
+requests = ">=2.8.1,<3.0"
+tqdm = ">=4.11.2,<5.0"
+
+[package.extras]
+doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"]
+optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"]
+test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"]
+
+[[package]]
+name = "multidict"
+version = "6.0.2"
+description = "multidict implementation"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "mypy-extensions"
+version = "0.4.3"
+description = "Experimental type system extensions for programs checked with the mypy typechecker."
+category = "dev"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "nbclassic"
+version = "0.5.5"
+description = "Jupyter Notebook as a Jupyter Server extension."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+argon2-cffi = "*"
+ipykernel = "*"
+ipython-genutils = "*"
+jinja2 = "*"
+jupyter-client = ">=6.1.1"
+jupyter-core = ">=4.6.1"
+jupyter-server = ">=1.8"
+nbconvert = ">=5"
+nbformat = "*"
+nest-asyncio = ">=1.5"
+notebook-shim = ">=0.1.0"
+prometheus-client = "*"
+pyzmq = ">=17"
+Send2Trash = ">=1.8.0"
+terminado = ">=0.8.3"
+tornado = ">=6.1"
+traitlets = ">=4.2.1"
+
+[package.extras]
+docs = ["myst-parser", "nbsphinx", "sphinx", "sphinx-rtd-theme", "sphinxcontrib-github-alt"]
+json-logging = ["json-logging"]
+test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-jupyter", "pytest-playwright", "pytest-tornasync", "requests", "requests-unixsocket", "testpath"]
+
+[[package]]
+name = "nbclient"
+version = "0.7.3"
+description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor."
+category = "main"
+optional = false
+python-versions = ">=3.7.0"
+
+[package.dependencies]
+jupyter-client = ">=6.1.12"
+jupyter-core = ">=4.12,<5.0.0 || >=5.1.0"
+nbformat = ">=5.1"
+traitlets = ">=5.3"
+
+[package.extras]
+dev = ["pre-commit"]
+docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"]
+test = ["flaky", "ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"]
+
+[[package]]
+name = "nbconvert"
+version = "7.3.1"
+description = "Converting Jupyter Notebooks"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+beautifulsoup4 = "*"
+bleach = "*"
+defusedxml = "*"
+importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""}
+jinja2 = ">=3.0"
+jupyter-core = ">=4.7"
+jupyterlab-pygments = "*"
+markupsafe = ">=2.0"
+mistune = ">=2.0.3,<3"
+nbclient = ">=0.5.0"
+nbformat = ">=5.1"
+packaging = "*"
+pandocfilters = ">=1.4.1"
+pygments = ">=2.4.1"
+tinycss2 = "*"
+traitlets = ">=5.0"
+
+[package.extras]
+all = ["nbconvert[docs,qtpdf,serve,test,webpdf]"]
+docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (==5.0.2)", "sphinxcontrib-spelling"]
+qtpdf = ["nbconvert[qtpng]"]
+qtpng = ["pyqtwebengine (>=5.15)"]
+serve = ["tornado (>=6.1)"]
+test = ["ipykernel", "ipywidgets (>=7)", "pre-commit", "pytest", "pytest-dependency"]
+webpdf = ["pyppeteer (>=1,<1.1)"]
+
+[[package]]
+name = "nbformat"
+version = "5.8.0"
+description = "The Jupyter Notebook format"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+fastjsonschema = "*"
+importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.8\""}
+jsonschema = ">=2.6"
+jupyter-core = "*"
+traitlets = ">=5.1"
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"]
+test = ["pep440", "pre-commit", "pytest", "testpath"]
+
+[[package]]
+name = "nest-asyncio"
+version = "1.5.6"
+description = "Patch asyncio to allow nested event loops"
+category = "main"
+optional = false
+python-versions = ">=3.5"
+
+[[package]]
+name = "nodeenv"
+version = "1.7.0"
+description = "Node.js virtual environment builder"
+category = "dev"
+optional = false
+python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
+
+[package.dependencies]
+setuptools = "*"
+
+[[package]]
+name = "notebook"
+version = "6.5.4"
+description = "A web-based notebook environment for interactive computing"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+argon2-cffi = "*"
+ipykernel = "*"
+ipython-genutils = "*"
+jinja2 = "*"
+jupyter-client = ">=5.3.4"
+jupyter-core = ">=4.6.1"
+nbclassic = ">=0.4.7"
+nbconvert = ">=5"
+nbformat = "*"
+nest-asyncio = ">=1.5"
+prometheus-client = "*"
+pyzmq = ">=17"
+Send2Trash = ">=1.8.0"
+terminado = ">=0.8.3"
+tornado = ">=6.1"
+traitlets = ">=4.2.1"
+
+[package.extras]
+docs = ["myst-parser", "nbsphinx", "sphinx", "sphinx-rtd-theme", "sphinxcontrib-github-alt"]
+json-logging = ["json-logging"]
+test = ["coverage", "nbval", "pytest", "pytest-cov", "requests", "requests-unixsocket", "selenium (==4.1.5)", "testpath"]
+
+[[package]]
+name = "notebook-shim"
+version = "0.2.3"
+description = "A shim layer for notebook traits and config"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+jupyter-server = ">=1.8,<3"
+
+[package.extras]
+test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"]
+
+[[package]]
+name = "numpy"
+version = "1.21.6"
+description = "NumPy is the fundamental package for array computing with Python."
+category = "main"
+optional = false
+python-versions = ">=3.7,<3.11"
+
+[[package]]
+name = "oauthlib"
+version = "3.2.0"
+description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.extras]
+rsa = ["cryptography (>=3.0.0)"]
+signals = ["blinker (>=1.4.0)"]
+signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
+
+[[package]]
+name = "omegaconf"
+version = "2.2.2"
+description = "A flexible configuration library"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+antlr4-python3-runtime = ">=4.9.0,<4.10.0"
+PyYAML = ">=5.1.0"
+
+[[package]]
+name = "packaging"
+version = "21.3"
+description = "Core utilities for Python packages"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
+
+[[package]]
+name = "pandas"
+version = "1.3.5"
+description = "Powerful data structures for data analysis, time series, and statistics"
+category = "main"
+optional = false
+python-versions = ">=3.7.1"
+
+[package.dependencies]
+numpy = [
+ {version = ">=1.17.3", markers = "platform_machine != \"aarch64\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
+ {version = ">=1.19.2", markers = "platform_machine == \"aarch64\" and python_version < \"3.10\""},
+ {version = ">=1.20.0", markers = "platform_machine == \"arm64\" and python_version < \"3.10\""},
+]
+python-dateutil = ">=2.7.3"
+pytz = ">=2017.3"
+
+[package.extras]
+test = ["hypothesis (>=3.58)", "pytest (>=6.0)", "pytest-xdist"]
+
+[[package]]
+name = "pandocfilters"
+version = "1.5.0"
+description = "Utilities for writing pandoc filters in python"
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+
+[[package]]
+name = "parso"
+version = "0.8.3"
+description = "A Python Parser"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.extras]
+qa = ["flake8 (==3.8.3)", "mypy (==0.782)"]
+testing = ["docopt", "pytest (<6.0.0)"]
+
+[[package]]
+name = "pathspec"
+version = "0.9.0"
+description = "Utility library for gitignore style pattern matching of file paths."
+category = "dev"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
+
+[[package]]
+name = "pexpect"
+version = "4.8.0"
+description = "Pexpect allows easy control of interactive console applications."
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+ptyprocess = ">=0.5"
+
+[[package]]
+name = "pickleshare"
+version = "0.7.5"
+description = "Tiny 'shelve'-like database with concurrency support"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "pillow"
+version = "9.0.1"
+description = "Python Imaging Library (Fork)"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "pkgutil_resolve_name"
+version = "1.3.10"
+description = "Resolve a name to an object."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "platformdirs"
+version = "2.5.2"
+description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)"]
+test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"]
+
+[[package]]
+name = "pluggy"
+version = "1.0.0"
+description = "plugin and hook calling mechanisms for python"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""}
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
+name = "pre-commit"
+version = "2.19.0"
+description = "A framework for managing and maintaining multi-language pre-commit hooks."
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+cfgv = ">=2.0.0"
+identify = ">=1.0.0"
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+nodeenv = ">=0.11.1"
+pyyaml = ">=5.1"
+toml = "*"
+virtualenv = ">=20.0.8"
+
+[[package]]
+name = "proglog"
+version = "0.1.10"
+description = "Log and progress bar manager for console, notebooks, web..."
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+tqdm = "*"
+
+[[package]]
+name = "prometheus-client"
+version = "0.15.0"
+description = "Python client for the Prometheus monitoring system."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.extras]
+twisted = ["twisted"]
+
+[[package]]
+name = "prometheus-flask-exporter"
+version = "0.20.3"
+description = "Prometheus metrics exporter for Flask"
+category = "main"
+optional = true
+python-versions = "*"
+
+[package.dependencies]
+flask = "*"
+prometheus-client = "*"
+
+[[package]]
+name = "prompt-toolkit"
+version = "3.0.38"
+description = "Library for building powerful interactive command lines in Python"
+category = "main"
+optional = false
+python-versions = ">=3.7.0"
+
+[package.dependencies]
+wcwidth = "*"
+
+[[package]]
+name = "protobuf"
+version = "3.20.1"
+description = "Protocol Buffers"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "psutil"
+version = "5.9.5"
+description = "Cross-platform lib for process and system monitoring in Python."
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+
+[package.extras]
+test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
+
+[[package]]
+name = "ptyprocess"
+version = "0.7.0"
+description = "Run a subprocess in a pseudo terminal"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "py"
+version = "1.11.0"
+description = "library with cross-python path, ini-parsing, io, code, log facilities"
+category = "dev"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+
+[[package]]
+name = "pyamg"
+version = "4.2.3"
+description = "PyAMG: Algebraic Multigrid Solvers in Python"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+numpy = ">=1.7.0"
+scipy = ">=0.12.0"
+
+[[package]]
+name = "pyasn1"
+version = "0.4.8"
+description = "ASN.1 types and codecs"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.2.8"
+description = "A collection of ASN.1-based protocols modules."
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+pyasn1 = ">=0.4.6,<0.5.0"
+
+[[package]]
+name = "pycodestyle"
+version = "2.8.0"
+description = "Python style guide checker"
+category = "dev"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+
+[[package]]
+name = "pycparser"
+version = "2.21"
+description = "C parser in Python"
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+
+[[package]]
+name = "pydeprecate"
+version = "0.3.2"
+description = "Deprecation tooling"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "pydocstyle"
+version = "6.1.1"
+description = "Python docstring style checker"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+snowballstemmer = "*"
+
+[package.extras]
+toml = ["toml"]
+
+[[package]]
+name = "pyflakes"
+version = "2.4.0"
+description = "passive checker of Python programs"
+category = "dev"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+
+[[package]]
+name = "Pygments"
+version = "2.15.1"
+description = "Pygments is a syntax highlighting package written in Python."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+plugins = ["importlib-metadata"]
+
+[[package]]
+name = "PyJWT"
+version = "2.6.0"
+description = "JSON Web Token implementation in Python"
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.extras]
+crypto = ["cryptography (>=3.4.0)"]
+dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
+tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
+
+[[package]]
+name = "pyparsing"
+version = "3.0.9"
+description = "pyparsing module - Classes and methods to define and execute parsing grammars"
+category = "main"
+optional = false
+python-versions = ">=3.6.8"
+
+[package.extras]
+diagrams = ["jinja2", "railroad-diagrams"]
+
+[[package]]
+name = "pyrsistent"
+version = "0.19.3"
+description = "Persistent/Functional/Immutable data structures"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "pytest"
+version = "7.1.2"
+description = "pytest: simple powerful testing with Python"
+category = "dev"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
+attrs = ">=19.2.0"
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=0.12,<2.0"
+py = ">=1.8.2"
+tomli = ">=1.0.0"
+
+[package.extras]
+testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
+
+[[package]]
+name = "python-dateutil"
+version = "2.8.2"
+description = "Extensions to the standard Python datetime module"
+category = "main"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+
+[package.dependencies]
+six = ">=1.5"
+
+[[package]]
+name = "pytorch-lightning"
+version = "1.6.4"
+description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+fsspec = {version = ">=2021.05.0,<2021.06.0 || >2021.06.0", extras = ["http"]}
+numpy = ">=1.17.2"
+packaging = ">=17.0"
+protobuf = "<=3.20.1"
+pyDeprecate = ">=0.3.1"
+PyYAML = ">=5.4"
+tensorboard = ">=2.2.0"
+torch = ">=1.8"
+torchmetrics = ">=0.4.1"
+tqdm = ">=4.57.0"
+typing-extensions = ">=4.0.0"
+
+[package.extras]
+all = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "deepspeed", "fairscale (>=0.4.5)", "flake8 (>=3.9.2)", "gcsfs (>=2021.5.0)", "gym[classic_control] (>=0.17.0)", "hivemind (>=1.0.1)", "horovod (>=0.21.2,!=0.24.0)", "hydra-core (>=1.0.5)", "ipython[all]", "jsonargparse[signatures] (>=4.7.1)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (>=0.920)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.2.2,!=10.15.*)", "scikit-learn (>0.22.1)", "test-tube (>=0.7.5)", "torchtext (>=0.9)", "torchvision (>=0.9)", "wandb (>=0.8.21)"]
+deepspeed = ["deepspeed"]
+dev = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "comet-ml (>=3.1.12)", "coverage (>=6.4)", "flake8 (>=3.9.2)", "gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "matplotlib (>3.1)", "mlflow (>=1.0.0)", "mypy (>=0.920)", "neptune-client (>=0.10.0)", "omegaconf (>=2.0.5)", "onnxruntime", "pandas", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-forked", "pytest-rerunfailures (>=10.2)", "rich (>=10.2.2,!=10.15.*)", "scikit-learn (>0.22.1)", "test-tube (>=0.7.5)", "torchtext (>=0.9)", "wandb (>=0.8.21)"]
+examples = ["gym[classic_control] (>=0.17.0)", "ipython[all]", "torchvision (>=0.9)"]
+extra = ["gcsfs (>=2021.5.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.7.1)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=10.2.2,!=10.15.*)", "torchtext (>=0.9)"]
+fairscale = ["fairscale (>=0.4.5)"]
+hivemind = ["hivemind (>=1.0.1)"]
+horovod = ["horovod (>=0.21.2,!=0.24.0)"]
+loggers = ["comet-ml (>=3.1.12)", "mlflow (>=1.0.0)", "neptune-client (>=0.10.0)", "test-tube (>=0.7.5)", "wandb (>=0.8.21)"]
+strategies = ["deepspeed", "fairscale (>=0.4.5)", "hivemind (>=1.0.1)", "horovod (>=0.21.2,!=0.24.0)"]
+test = ["cloudpickle (>=1.3)", "codecov (>=2.1)", "coverage (>=6.4)", "flake8 (>=3.9.2)", "mypy (>=0.920)", "onnxruntime", "pandas", "pre-commit (>=1.0)", "pytest (>=6.0)", "pytest-forked", "pytest-rerunfailures (>=10.2)", "scikit-learn (>0.22.1)"]
+
+[[package]]
+name = "pytz"
+version = "2022.4"
+description = "World timezone definitions, modern and historical"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "pywin32"
+version = "304"
+description = "Python for Window Extensions"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "pywinpty"
+version = "2.0.10"
+description = "Pseudo terminal support for Windows from Python."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "pyyaml"
+version = "5.4.1"
+description = "YAML parser and emitter for Python"
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
+
+[[package]]
+name = "pyzmq"
+version = "25.0.2"
+description = "Python bindings for 0MQ"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+cffi = {version = "*", markers = "implementation_name == \"pypy\""}
+
+[[package]]
+name = "qtconsole"
+version = "5.4.2"
+description = "Jupyter Qt console"
+category = "main"
+optional = false
+python-versions = ">= 3.7"
+
+[package.dependencies]
+ipykernel = ">=4.1"
+ipython-genutils = "*"
+jupyter-client = ">=4.1"
+jupyter-core = "*"
+packaging = "*"
+pygments = "*"
+pyzmq = ">=17.1"
+qtpy = ">=2.0.1"
+traitlets = "<5.2.1 || >5.2.1,<5.2.2 || >5.2.2"
+
+[package.extras]
+doc = ["Sphinx (>=1.3)"]
+test = ["flaky", "pytest", "pytest-qt"]
+
+[[package]]
+name = "QtPy"
+version = "2.3.1"
+description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+packaging = "*"
+
+[package.extras]
+test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"]
+
+[[package]]
+name = "querystring-parser"
+version = "1.2.4"
+description = "QueryString parser for Python/Django that correctly handles nested dictionaries"
+category = "main"
+optional = true
+python-versions = "*"
+
+[package.dependencies]
+six = "*"
+
+[[package]]
+name = "regex"
+version = "2022.7.9"
+description = "Alternative regular expression module, to replace re."
+category = "main"
+optional = true
+python-versions = ">=3.6"
+
+[[package]]
+name = "requests"
+version = "2.28.1"
+description = "Python HTTP for Humans."
+category = "main"
+optional = false
+python-versions = ">=3.7, <4"
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset-normalizer = ">=2,<3"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<1.27"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"]
+
+[[package]]
+name = "requests-oauthlib"
+version = "1.3.1"
+description = "OAuthlib authentication support for Requests."
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+
+[package.dependencies]
+oauthlib = ">=3.0.0"
+requests = ">=2.0.0"
+
+[package.extras]
+rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
+
+[[package]]
+name = "rsa"
+version = "4.7.2"
+description = "Pure-Python RSA implementation"
+category = "main"
+optional = false
+python-versions = ">=3.5, <4"
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
+[[package]]
+name = "s3transfer"
+version = "0.6.0"
+description = "An Amazon S3 Transfer Manager"
+category = "main"
+optional = false
+python-versions = ">= 3.7"
+
+[package.dependencies]
+botocore = ">=1.12.36,<2.0a.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
+
+[[package]]
+name = "scikit-learn"
+version = "1.0.2"
+description = "A set of python modules for machine learning and data mining"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+joblib = ">=0.11"
+numpy = ">=1.14.6"
+scipy = ">=1.1.0"
+threadpoolctl = ">=2.0.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=2.2.3)", "memory-profiler (>=0.57.0)", "pandas (>=0.25.0)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=2.2.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.0.0)", "pandas (>=0.25.0)", "scikit-image (>=0.14.5)", "seaborn (>=0.9.0)", "sphinx (>=4.0.1)", "sphinx-gallery (>=0.7.0)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"]
+examples = ["matplotlib (>=2.2.3)", "pandas (>=0.25.0)", "scikit-image (>=0.14.5)", "seaborn (>=0.9.0)"]
+tests = ["black (>=21.6b0)", "flake8 (>=3.8.2)", "matplotlib (>=2.2.3)", "mypy (>=0.770)", "pandas (>=0.25.0)", "pyamg (>=4.0.0)", "pytest (>=5.0.1)", "pytest-cov (>=2.9.0)", "scikit-image (>=0.14.5)"]
+
+[[package]]
+name = "scipy"
+version = "1.7.3"
+description = "SciPy: Scientific Library for Python"
+category = "main"
+optional = false
+python-versions = ">=3.7,<3.11"
+
+[package.dependencies]
+numpy = ">=1.16.5,<1.23.0"
+
+[[package]]
+name = "Send2Trash"
+version = "1.8.0"
+description = "Send file to trash natively under Mac OS X, Windows and Linux."
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.extras]
+nativelib = ["pyobjc-framework-Cocoa", "pywin32"]
+objc = ["pyobjc-framework-Cocoa"]
+win32 = ["pywin32"]
+
+[[package]]
+name = "setuptools"
+version = "65.4.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mock", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
+
+[[package]]
+name = "setuptools-scm"
+version = "7.0.4"
+description = "the blessed package to manage your versions by scm tags"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+packaging = ">=20.0"
+setuptools = "*"
+tomli = ">=1.0.0"
+typing-extensions = "*"
+
+[package.extras]
+test = ["pytest (>=6.2)", "virtualenv (>20)"]
+toml = ["setuptools (>=42)"]
+
+[[package]]
+name = "six"
+version = "1.16.0"
+description = "Python 2 and 3 compatibility utilities"
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
+
+[[package]]
+name = "smmap"
+version = "5.0.0"
+description = "A pure Python implementation of a sliding window memory map manager"
+category = "main"
+optional = true
+python-versions = ">=3.6"
+
+[[package]]
+name = "sniffio"
+version = "1.3.0"
+description = "Sniff out which async library your code is running under"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "snowballstemmer"
+version = "2.2.0"
+description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms."
+category = "dev"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "soupsieve"
+version = "2.4.1"
+description = "A modern CSS selector implementation for Beautiful Soup."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "SQLAlchemy"
+version = "1.4.42"
+description = "Database Abstraction Library"
+category = "main"
+optional = true
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+
+[package.dependencies]
+greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}
+importlib-metadata = {version = "*", markers = "python_version < \"3.8\""}
+
+[package.extras]
+aiomysql = ["aiomysql", "greenlet (!=0.4.17)"]
+aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"]
+asyncio = ["greenlet (!=0.4.17)"]
+asyncmy = ["asyncmy (>=0.2.3,!=0.2.4)", "greenlet (!=0.4.17)"]
+mariadb_connector = ["mariadb (>=1.0.1,!=1.1.2)"]
+mssql = ["pyodbc"]
+mssql_pymssql = ["pymssql"]
+mssql_pyodbc = ["pyodbc"]
+mypy = ["mypy (>=0.910)", "sqlalchemy2-stubs"]
+mysql = ["mysqlclient (>=1.4.0)", "mysqlclient (>=1.4.0,<2)"]
+mysql_connector = ["mysql-connector-python"]
+oracle = ["cx_oracle (>=7)", "cx_oracle (>=7,<8)"]
+postgresql = ["psycopg2 (>=2.7)"]
+postgresql_asyncpg = ["asyncpg", "greenlet (!=0.4.17)"]
+postgresql_pg8000 = ["pg8000 (>=1.16.6,!=1.29.0)"]
+postgresql_psycopg2binary = ["psycopg2-binary"]
+postgresql_psycopg2cffi = ["psycopg2cffi"]
+pymysql = ["pymysql", "pymysql (<1)"]
+sqlcipher = ["sqlcipher3_binary"]
+
+[[package]]
+name = "sqlparse"
+version = "0.4.3"
+description = "A non-validating SQL parser."
+category = "main"
+optional = true
+python-versions = ">=3.5"
+
+[[package]]
+name = "submitit"
+version = "1.4.2"
+description = "\"Python 3.6+ toolbox for submitting jobs to Slurm"
+category = "main"
+optional = true
+python-versions = ">=3.6"
+
+[package.dependencies]
+cloudpickle = ">=1.2.1"
+typing_extensions = ">=3.7.4.2"
+
+[package.extras]
+dev = ["black (==22.3.0)", "coverage[toml] (>=5.1)", "flit (>=3.5.1)", "isort (==5.5.3)", "mypy (>=0.782)", "pre-commit (>=1.15.2)", "pylint (>=2.8.0)", "pytest (>=4.3.0)", "pytest-asyncio (>=0.15.0)", "pytest-cov (>=2.6.1)", "types-pkg_resources (>=0.1.2)"]
+
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+category = "main"
+optional = true
+python-versions = ">=3.7"
+
+[package.extras]
+widechars = ["wcwidth"]
+
+[[package]]
+name = "tensorboard"
+version = "2.9.0"
+description = "TensorBoard lets you watch Tensors Flow"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+absl-py = ">=0.4"
+google-auth = ">=1.6.3,<3"
+google-auth-oauthlib = ">=0.4.1,<0.5"
+grpcio = ">=1.24.3"
+markdown = ">=2.6.8"
+numpy = ">=1.12.0"
+protobuf = ">=3.9.2"
+requests = ">=2.21.0,<3"
+setuptools = ">=41.0.0"
+tensorboard-data-server = ">=0.6.0,<0.7.0"
+tensorboard-plugin-wit = ">=1.6.0"
+werkzeug = ">=1.0.1"
+wheel = ">=0.26"
+
+[[package]]
+name = "tensorboard-data-server"
+version = "0.6.1"
+description = "Fast data loading for TensorBoard"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "tensorboard-plugin-wit"
+version = "1.8.1"
+description = "What-If Tool TensorBoard plugin."
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "terminado"
+version = "0.17.1"
+description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+ptyprocess = {version = "*", markers = "os_name != \"nt\""}
+pywinpty = {version = ">=1.1.0", markers = "os_name == \"nt\""}
+tornado = ">=6.1.0"
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
+test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"]
+
+[[package]]
+name = "testfixtures"
+version = "6.18.5"
+description = "A collection of helpers and mock objects for unit tests and doc tests."
+category = "dev"
+optional = false
+python-versions = "*"
+
+[package.extras]
+build = ["setuptools-git", "twine", "wheel"]
+docs = ["django", "django (<2)", "mock", "sphinx", "sybil", "twisted", "zope.component"]
+test = ["django", "django (<2)", "mock", "pytest (>=3.6)", "pytest-cov", "pytest-django", "sybil", "twisted", "zope.component"]
+
+[[package]]
+name = "threadpoolctl"
+version = "3.1.0"
+description = "threadpoolctl"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "timm"
+version = "0.6.7"
+description = "(Unofficial) PyTorch Image Models"
+category = "main"
+optional = true
+python-versions = ">=3.6"
+
+[package.dependencies]
+torch = ">=1.4"
+torchvision = "*"
+
+[[package]]
+name = "tinycss2"
+version = "1.2.1"
+description = "A tiny CSS parser"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+webencodings = ">=0.4"
+
+[package.extras]
+doc = ["sphinx", "sphinx_rtd_theme"]
+test = ["flake8", "isort", "pytest"]
+
+[[package]]
+name = "toml"
+version = "0.10.2"
+description = "Python Library for Tom's Obvious, Minimal Language"
+category = "dev"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
+
+[[package]]
+name = "tomli"
+version = "2.0.1"
+description = "A lil' TOML parser"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "torch"
+version = "1.12.1"
+description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+category = "main"
+optional = false
+python-versions = ">=3.7.0"
+
+[package.dependencies]
+typing-extensions = "*"
+
+[[package]]
+name = "torchmetrics"
+version = "0.8.2"
+description = "PyTorch native Metrics"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+numpy = ">=1.17.2"
+packaging = "*"
+pyDeprecate = ">=0.3.0,<0.4.0"
+torch = ">=1.3.1"
+
+[package.extras]
+all = ["bert-score (==0.3.10)", "check-manifest", "cloudpickle (>=1.3)", "codecov (>=2.1)", "coverage (>5.2)", "docutils (>=0.16)", "fast-bss-eval (>=0.1.0)", "fire", "jiwer (>=2.3.0)", "lpips", "mir-eval (>=0.6)", "mypy (>=0.790)", "myst-parser", "nbsphinx (>=0.8)", "nltk (>=3.6)", "pandoc (>=1.0)", "pesq (>=0.0.3)", "phmdoctest (>=1.1.1)", "pre-commit (>=1.0)", "psutil", "pypesq", "pystoi", "pytest (>=6.0.0,<7.0.0)", "pytest-cov (>2.10)", "pytest-doctestplus (>=0.9.0)", "pytorch-lightning (>=1.5)", "pytorch-msssim", "regex (>=2021.9.24)", "requests", "rouge-score (>=0.0.4)", "sacrebleu (>=2.0.0)", "scikit-image (>0.17.1)", "scikit-learn (>=0.24)", "scipy", "sphinx (>=4.0)", "sphinx-autodoc-typehints (>=1.0)", "sphinx-copybutton (>=0.3)", "sphinx-paramlinks (>=0.5.1)", "sphinx-togglebutton (>=0.2)", "sphinxcontrib-fulltoc (>=1.0)", "sphinxcontrib-mockautodoc", "torch-complex", "torch-fidelity", "torchvision", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>=4.0)", "twine (>=3.2)"]
+audio = ["pesq (>=0.0.3)", "pystoi"]
+detection = ["torchvision (>=0.8)"]
+docs = ["docutils (>=0.16)", "myst-parser", "nbsphinx (>=0.8)", "pandoc (>=1.0)", "sphinx (>=4.0)", "sphinx-autodoc-typehints (>=1.0)", "sphinx-copybutton (>=0.3)", "sphinx-paramlinks (>=0.5.1)", "sphinx-togglebutton (>=0.2)", "sphinxcontrib-fulltoc (>=1.0)", "sphinxcontrib-mockautodoc"]
+image = ["lpips", "scipy", "torch-fidelity", "torchvision"]
+integrate = ["pytorch-lightning (>=1.5)"]
+test = ["bert-score (==0.3.10)", "check-manifest", "cloudpickle (>=1.3)", "codecov (>=2.1)", "coverage (>5.2)", "fast-bss-eval (>=0.1.0)", "fire", "jiwer (>=2.3.0)", "mir-eval (>=0.6)", "mypy (>=0.790)", "phmdoctest (>=1.1.1)", "pre-commit (>=1.0)", "psutil", "pypesq", "pytest (>=6.0.0,<7.0.0)", "pytest-cov (>2.10)", "pytest-doctestplus (>=0.9.0)", "pytorch-msssim", "requests", "rouge-score (>=0.0.4)", "sacrebleu (>=2.0.0)", "scikit-image (>0.17.1)", "scikit-learn (>=0.24)", "torch-complex", "transformers (>=4.0)", "twine (>=3.2)"]
+text = ["nltk (>=3.6)", "regex (>=2021.9.24)", "tqdm (>=4.41.0)"]
+
+[[package]]
+name = "torchtyping"
+version = "0.1.4"
+description = "Runtime type annotations for the shape, dtype etc. of PyTorch Tensors."
+category = "main"
+optional = false
+python-versions = ">=3.7.0"
+
+[package.dependencies]
+torch = ">=1.7.0"
+typeguard = ">=2.11.1"
+
+[[package]]
+name = "torchvision"
+version = "0.13.1"
+description = "image and video datasets and models for torch deep learning"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.dependencies]
+numpy = "*"
+pillow = ">=5.3.0,<8.3.0 || >=8.4.0"
+requests = "*"
+torch = "1.12.1"
+typing-extensions = "*"
+
+[package.extras]
+scipy = ["scipy"]
+
+[[package]]
+name = "tornado"
+version = "6.2"
+description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
+category = "main"
+optional = false
+python-versions = ">= 3.7"
+
+[[package]]
+name = "tqdm"
+version = "4.64.0"
+description = "Fast, Extensible Progress Meter"
+category = "main"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["py-make (>=0.1.0)", "twine", "wheel"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
+[[package]]
+name = "traitlets"
+version = "5.9.0"
+description = "Traitlets Python configuration system"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
+test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"]
+
+[[package]]
+name = "typed-ast"
+version = "1.5.4"
+description = "a fork of Python 2 and 3 ast modules with type comment support"
+category = "dev"
+optional = false
+python-versions = ">=3.6"
+
+[[package]]
+name = "typeguard"
+version = "2.13.3"
+description = "Run-time type checker for Python"
+category = "main"
+optional = false
+python-versions = ">=3.5.3"
+
+[package.extras]
+doc = ["sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
+test = ["mypy", "pytest", "typing-extensions"]
+
+[[package]]
+name = "typing-extensions"
+version = "4.3.0"
+description = "Backported and Experimental Type Hints for Python 3.7+"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "urllib3"
+version = "1.26.9"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+category = "main"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
+
+[package.extras]
+brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
+secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)"]
+socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
+
+[[package]]
+name = "virtualenv"
+version = "20.15.1"
+description = "Virtual Python Environment builder"
+category = "dev"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
+
+[package.dependencies]
+distlib = ">=0.3.1,<1"
+filelock = ">=3.2,<4"
+importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""}
+platformdirs = ">=2,<3"
+six = ">=1.9.0,<2"
+
+[package.extras]
+docs = ["proselint (>=0.10.2)", "sphinx (>=3)", "sphinx-argparse (>=0.2.5)", "sphinx-rtd-theme (>=0.4.3)", "towncrier (>=21.3)"]
+testing = ["coverage (>=4)", "coverage-enable-subprocess (>=1)", "flaky (>=3)", "packaging (>=20.0)", "pytest (>=4)", "pytest-env (>=0.6.2)", "pytest-freezegun (>=0.4.1)", "pytest-mock (>=2)", "pytest-randomly (>=1)", "pytest-timeout (>=1)"]
+
+[[package]]
+name = "waitress"
+version = "2.1.2"
+description = "Waitress WSGI server"
+category = "main"
+optional = true
+python-versions = ">=3.7.0"
+
+[package.extras]
+docs = ["Sphinx (>=1.8.1)", "docutils", "pylons-sphinx-themes (>=1.0.9)"]
+testing = ["coverage (>=5.0)", "pytest", "pytest-cover"]
+
+[[package]]
+name = "wcwidth"
+version = "0.2.5"
+description = "Measures the displayed width of unicode strings in a terminal"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "webdataset"
+version = "0.1.103"
+description = "Record sequential storage for deep learning."
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+braceexpand = "*"
+numpy = "*"
+pyyaml = "*"
+
+[[package]]
+name = "webencodings"
+version = "0.5.1"
+description = "Character encoding aliases for legacy web content"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "websocket-client"
+version = "1.4.1"
+description = "WebSocket client for Python with low level API options"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"]
+optional = ["python-socks", "wsaccel"]
+test = ["websockets"]
+
+[[package]]
+name = "werkzeug"
+version = "2.1.2"
+description = "The comprehensive WSGI web application library."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+watchdog = ["watchdog"]
+
+[[package]]
+name = "wheel"
+version = "0.37.1"
+description = "A built-package format for Python"
+category = "main"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
+
+[package.extras]
+test = ["pytest (>=3.0.0)", "pytest-cov"]
+
+[[package]]
+name = "widgetsnbextension"
+version = "4.0.7"
+description = "Jupyter interactive widgets for Jupyter Notebook"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[[package]]
+name = "xmltodict"
+version = "0.13.0"
+description = "Makes working with XML feel like you are working with JSON"
+category = "main"
+optional = false
+python-versions = ">=3.4"
+
+[[package]]
+name = "yarl"
+version = "1.7.2"
+description = "Yet another URL library"
+category = "main"
+optional = false
+python-versions = ">=3.6"
+
+[package.dependencies]
+idna = ">=2.0"
+multidict = ">=4.0"
+typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""}
+
+[[package]]
+name = "zipp"
+version = "3.8.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"]
+testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.0.1)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
+
+[extras]
+clip = ["ftfy", "regex"]
+mlflow = ["mlflow"]
+submitit = ["hydra-submitit-launcher"]
+timm = ["timm"]
+
+[metadata]
+lock-version = "1.1"
+python-versions = ">=3.7.1,<3.9"
+content-hash = "f2bb550a2368d6220e6b4b8aff0d54c7a2b07f732737a2bcf9aa69f2a34a5c2c"
+
+[metadata.files]
+absl-py = [
+ {file = "absl-py-1.1.0.tar.gz", hash = "sha256:3aa39f898329c2156ff525dfa69ce709e42d77aab18bf4917719d6f260aa6a08"},
+ {file = "absl_py-1.1.0-py3-none-any.whl", hash = "sha256:db97287655e30336938f8058d2c81ed2be6af1d9b6ebbcd8df1080a6c7fcd24e"},
+]
+aiohttp = [
+ {file = "aiohttp-3.8.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1ed0b6477896559f17b9eaeb6d38e07f7f9ffe40b9f0f9627ae8b9926ae260a8"},
+ {file = "aiohttp-3.8.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7dadf3c307b31e0e61689cbf9e06be7a867c563d5a63ce9dca578f956609abf8"},
+ {file = "aiohttp-3.8.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a79004bb58748f31ae1cbe9fa891054baaa46fb106c2dc7af9f8e3304dc30316"},
+ {file = "aiohttp-3.8.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12de6add4038df8f72fac606dff775791a60f113a725c960f2bab01d8b8e6b15"},
+ {file = "aiohttp-3.8.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f0d5f33feb5f69ddd57a4a4bd3d56c719a141080b445cbf18f238973c5c9923"},
+ {file = "aiohttp-3.8.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eaba923151d9deea315be1f3e2b31cc39a6d1d2f682f942905951f4e40200922"},
+ {file = "aiohttp-3.8.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:099ebd2c37ac74cce10a3527d2b49af80243e2a4fa39e7bce41617fbc35fa3c1"},
+ {file = "aiohttp-3.8.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2e5d962cf7e1d426aa0e528a7e198658cdc8aa4fe87f781d039ad75dcd52c516"},
+ {file = "aiohttp-3.8.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:fa0ffcace9b3aa34d205d8130f7873fcfefcb6a4dd3dd705b0dab69af6712642"},
+ {file = "aiohttp-3.8.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61bfc23df345d8c9716d03717c2ed5e27374e0fe6f659ea64edcd27b4b044cf7"},
+ {file = "aiohttp-3.8.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:31560d268ff62143e92423ef183680b9829b1b482c011713ae941997921eebc8"},
+ {file = "aiohttp-3.8.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:01d7bdb774a9acc838e6b8f1d114f45303841b89b95984cbb7d80ea41172a9e3"},
+ {file = "aiohttp-3.8.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:97ef77eb6b044134c0b3a96e16abcb05ecce892965a2124c566af0fd60f717e2"},
+ {file = "aiohttp-3.8.1-cp310-cp310-win32.whl", hash = "sha256:c2aef4703f1f2ddc6df17519885dbfa3514929149d3ff900b73f45998f2532fa"},
+ {file = "aiohttp-3.8.1-cp310-cp310-win_amd64.whl", hash = "sha256:713ac174a629d39b7c6a3aa757b337599798da4c1157114a314e4e391cd28e32"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:473d93d4450880fe278696549f2e7aed8cd23708c3c1997981464475f32137db"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99b5eeae8e019e7aad8af8bb314fb908dd2e028b3cdaad87ec05095394cce632"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3af642b43ce56c24d063325dd2cf20ee012d2b9ba4c3c008755a301aaea720ad"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3630c3ef435c0a7c549ba170a0633a56e92629aeed0e707fec832dee313fb7a"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4a4a4e30bf1edcad13fb0804300557aedd07a92cabc74382fdd0ba6ca2661091"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6f8b01295e26c68b3a1b90efb7a89029110d3a4139270b24fda961893216c440"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:a25fa703a527158aaf10dafd956f7d42ac6d30ec80e9a70846253dd13e2f067b"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:5bfde62d1d2641a1f5173b8c8c2d96ceb4854f54a44c23102e2ccc7e02f003ec"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:51467000f3647d519272392f484126aa716f747859794ac9924a7aafa86cd411"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:03a6d5349c9ee8f79ab3ff3694d6ce1cfc3ced1c9d36200cb8f08ba06bd3b782"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:102e487eeb82afac440581e5d7f8f44560b36cf0bdd11abc51a46c1cd88914d4"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-win32.whl", hash = "sha256:4aed991a28ea3ce320dc8ce655875e1e00a11bdd29fe9444dd4f88c30d558602"},
+ {file = "aiohttp-3.8.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b0e20cddbd676ab8a64c774fefa0ad787cc506afd844de95da56060348021e96"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:37951ad2f4a6df6506750a23f7cbabad24c73c65f23f72e95897bb2cecbae676"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c23b1ad869653bc818e972b7a3a79852d0e494e9ab7e1a701a3decc49c20d51"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:15b09b06dae900777833fe7fc4b4aa426556ce95847a3e8d7548e2d19e34edb8"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:477c3ea0ba410b2b56b7efb072c36fa91b1e6fc331761798fa3f28bb224830dd"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2f2f69dca064926e79997f45b2f34e202b320fd3782f17a91941f7eb85502ee2"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ef9612483cb35171d51d9173647eed5d0069eaa2ee812793a75373447d487aa4"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:6d69f36d445c45cda7b3b26afef2fc34ef5ac0cdc75584a87ef307ee3c8c6d00"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:55c3d1072704d27401c92339144d199d9de7b52627f724a949fc7d5fc56d8b93"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:b9d00268fcb9f66fbcc7cd9fe423741d90c75ee029a1d15c09b22d23253c0a44"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:07b05cd3305e8a73112103c834e91cd27ce5b4bd07850c4b4dbd1877d3f45be7"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c34dc4958b232ef6188c4318cb7b2c2d80521c9a56c52449f8f93ab7bc2a8a1c"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-win32.whl", hash = "sha256:d2f9b69293c33aaa53d923032fe227feac867f81682f002ce33ffae978f0a9a9"},
+ {file = "aiohttp-3.8.1-cp37-cp37m-win_amd64.whl", hash = "sha256:6ae828d3a003f03ae31915c31fa684b9890ea44c9c989056fea96e3d12a9fa17"},
+ {file = "aiohttp-3.8.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0c7ebbbde809ff4e970824b2b6cb7e4222be6b95a296e46c03cf050878fc1785"},
+ {file = "aiohttp-3.8.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8b7ef7cbd4fec9a1e811a5de813311ed4f7ac7d93e0fda233c9b3e1428f7dd7b"},
+ {file = "aiohttp-3.8.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c3d6a4d0619e09dcd61021debf7059955c2004fa29f48788a3dfaf9c9901a7cd"},
+ {file = "aiohttp-3.8.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:718626a174e7e467f0558954f94af117b7d4695d48eb980146016afa4b580b2e"},
+ {file = "aiohttp-3.8.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:589c72667a5febd36f1315aa6e5f56dd4aa4862df295cb51c769d16142ddd7cd"},
+ {file = "aiohttp-3.8.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2ed076098b171573161eb146afcb9129b5ff63308960aeca4b676d9d3c35e700"},
+ {file = "aiohttp-3.8.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:086f92daf51a032d062ec5f58af5ca6a44d082c35299c96376a41cbb33034675"},
+ {file = "aiohttp-3.8.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:11691cf4dc5b94236ccc609b70fec991234e7ef8d4c02dd0c9668d1e486f5abf"},
+ {file = "aiohttp-3.8.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:31d1e1c0dbf19ebccbfd62eff461518dcb1e307b195e93bba60c965a4dcf1ba0"},
+ {file = "aiohttp-3.8.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:11a67c0d562e07067c4e86bffc1553f2cf5b664d6111c894671b2b8712f3aba5"},
+ {file = "aiohttp-3.8.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:bb01ba6b0d3f6c68b89fce7305080145d4877ad3acaed424bae4d4ee75faa950"},
+ {file = "aiohttp-3.8.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:44db35a9e15d6fe5c40d74952e803b1d96e964f683b5a78c3cc64eb177878155"},
+ {file = "aiohttp-3.8.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:844a9b460871ee0a0b0b68a64890dae9c415e513db0f4a7e3cab41a0f2fedf33"},
+ {file = "aiohttp-3.8.1-cp38-cp38-win32.whl", hash = "sha256:7d08744e9bae2ca9c382581f7dce1273fe3c9bae94ff572c3626e8da5b193c6a"},
+ {file = "aiohttp-3.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:04d48b8ce6ab3cf2097b1855e1505181bdd05586ca275f2505514a6e274e8e75"},
+ {file = "aiohttp-3.8.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f5315a2eb0239185af1bddb1abf472d877fede3cc8d143c6cddad37678293237"},
+ {file = "aiohttp-3.8.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a996d01ca39b8dfe77440f3cd600825d05841088fd6bc0144cc6c2ec14cc5f74"},
+ {file = "aiohttp-3.8.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:13487abd2f761d4be7c8ff9080de2671e53fff69711d46de703c310c4c9317ca"},
+ {file = "aiohttp-3.8.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea302f34477fda3f85560a06d9ebdc7fa41e82420e892fc50b577e35fc6a50b2"},
+ {file = "aiohttp-3.8.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2f635ce61a89c5732537a7896b6319a8fcfa23ba09bec36e1b1ac0ab31270d2"},
+ {file = "aiohttp-3.8.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e999f2d0e12eea01caeecb17b653f3713d758f6dcc770417cf29ef08d3931421"},
+ {file = "aiohttp-3.8.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0770e2806a30e744b4e21c9d73b7bee18a1cfa3c47991ee2e5a65b887c49d5cf"},
+ {file = "aiohttp-3.8.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d15367ce87c8e9e09b0f989bfd72dc641bcd04ba091c68cd305312d00962addd"},
+ {file = "aiohttp-3.8.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6c7cefb4b0640703eb1069835c02486669312bf2f12b48a748e0a7756d0de33d"},
+ {file = "aiohttp-3.8.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:71927042ed6365a09a98a6377501af5c9f0a4d38083652bcd2281a06a5976724"},
+ {file = "aiohttp-3.8.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:28d490af82bc6b7ce53ff31337a18a10498303fe66f701ab65ef27e143c3b0ef"},
+ {file = "aiohttp-3.8.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:b6613280ccedf24354406caf785db748bebbddcf31408b20c0b48cb86af76866"},
+ {file = "aiohttp-3.8.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:81e3d8c34c623ca4e36c46524a3530e99c0bc95ed068fd6e9b55cb721d408fb2"},
+ {file = "aiohttp-3.8.1-cp39-cp39-win32.whl", hash = "sha256:7187a76598bdb895af0adbd2fb7474d7f6025d170bc0a1130242da817ce9e7d1"},
+ {file = "aiohttp-3.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:1c182cb873bc91b411e184dab7a2b664d4fea2743df0e4d57402f7f3fa644bac"},
+ {file = "aiohttp-3.8.1.tar.gz", hash = "sha256:fc5471e1a54de15ef71c1bc6ebe80d4dc681ea600e68bfd1cbce40427f0b7578"},
+]
+aiosignal = [
+ {file = "aiosignal-1.2.0-py3-none-any.whl", hash = "sha256:26e62109036cd181df6e6ad646f91f0dcfd05fe16d0cb924138ff2ab75d64e3a"},
+ {file = "aiosignal-1.2.0.tar.gz", hash = "sha256:78ed67db6c7b7ced4f98e495e572106d5c432a93e1ddd1bf475e1dc05f5b7df2"},
+]
+alembic = [
+ {file = "alembic-1.8.1-py3-none-any.whl", hash = "sha256:0a024d7f2de88d738d7395ff866997314c837be6104e90c5724350313dee4da4"},
+ {file = "alembic-1.8.1.tar.gz", hash = "sha256:cd0b5e45b14b706426b833f06369b9a6d5ee03f826ec3238723ce8caaf6e5ffa"},
+]
+antlr4-python3-runtime = [
+ {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"},
+]
+anyio = [
+ {file = "anyio-3.6.2-py3-none-any.whl", hash = "sha256:fbbe32bd270d2a2ef3ed1c5d45041250284e31fc0a4df4a5a6071842051a51e3"},
+ {file = "anyio-3.6.2.tar.gz", hash = "sha256:25ea0d673ae30af41a0c442f81cf3b38c7e79fdc7b60335a4c14e05eb0947421"},
+]
+appnope = [
+ {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"},
+ {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"},
+]
+argon2-cffi = [
+ {file = "argon2-cffi-21.3.0.tar.gz", hash = "sha256:d384164d944190a7dd7ef22c6aa3ff197da12962bd04b17f64d4e93d934dba5b"},
+ {file = "argon2_cffi-21.3.0-py3-none-any.whl", hash = "sha256:8c976986f2c5c0e5000919e6de187906cfd81fb1c72bf9d88c01177e77da7f80"},
+]
+argon2-cffi-bindings = [
+ {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082"},
+ {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f"},
+ {file = "argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3e385d1c39c520c08b53d63300c3ecc28622f076f4c2b0e6d7e796e9f6502194"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3e3cc67fdb7d82c4718f19b4e7a87123caf8a93fde7e23cf66ac0337d3cb3f"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a22ad9800121b71099d0fb0a65323810a15f2e292f2ba450810a7316e128ee5"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9f8b450ed0547e3d473fdc8612083fd08dd2120d6ac8f73828df9b7d45bb351"},
+ {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:93f9bf70084f97245ba10ee36575f0c3f1e7d7724d67d8e5b08e61787c320ed7"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3b9ef65804859d335dc6b31582cad2c5166f0c3e7975f324d9ffaa34ee7e6583"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4966ef5848d820776f5f562a7d45fdd70c2f330c961d0d745b784034bd9f48d"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ef543a89dee4db46a1a6e206cd015360e5a75822f76df533845c3cbaf72670"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2937d286e2ad0cc79a7087d3c272832865f779430e0cc2b4f3718d3159b0cb"},
+ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"},
+]
+async-timeout = [
+ {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"},
+ {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"},
+]
+asynctest = [
+ {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"},
+ {file = "asynctest-0.13.0.tar.gz", hash = "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"},
+]
+atomicwrites = [
+ {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"},
+ {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"},
+]
+attrs = [
+ {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
+ {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
+]
+awscli = [
+ {file = "awscli-1.25.22-py3-none-any.whl", hash = "sha256:2629a680e491dbb0a7e50c5ca68977c15abfb26067bf45c13a0fd6fdb9575866"},
+ {file = "awscli-1.25.22.tar.gz", hash = "sha256:070c03820b82cf79a571d27570e92f139c3aaefcb1d821f5e3cac348ac14fea3"},
+]
+awscrt = [
+ {file = "awscrt-0.13.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dcf89e6296a27a283098e8507ef527cebbc657fe8dc43835a7dd703abd368b00"},
+ {file = "awscrt-0.13.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:960e31911f9c4e9276e91e37c93be41fb70b1f8f899b7fb7b879ff14bc5e8566"},
+ {file = "awscrt-0.13.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96fd1a41f4352e097801a2e52c5ebb26f2a8fe42808a1f7e8286676932324dc9"},
+ {file = "awscrt-0.13.8-cp310-cp310-win32.whl", hash = "sha256:80a78d830db6480c8529e4890820976fe0861d3089c73d33f447641e218232e2"},
+ {file = "awscrt-0.13.8-cp310-cp310-win_amd64.whl", hash = "sha256:9d0722164b156e7e69a075e7a1ad4c434a62066c335f575eacb8dce7c4e83269"},
+ {file = "awscrt-0.13.8-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b550bfccc314a5c1160febf8c6961fb31a1f42b7fb18427ea6fe6cf1e29dd633"},
+ {file = "awscrt-0.13.8-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6341686bcd584c59af88a4f1ec10483008a298e99e6e1331ff35c268f61501b1"},
+ {file = "awscrt-0.13.8-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e189af34af5eab00448583a6084d08ade0acf83e77892e6add8d35c4033e9215"},
+ {file = "awscrt-0.13.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef93eb4f22d46563d5d5b090ce958b65ede8c620c93b49f2bf289a4065a2204a"},
+ {file = "awscrt-0.13.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:320435f0d49bc1bd83eeb958f7790c5826920b2c7ea8bbf7ebfd88d14fe69f58"},
+ {file = "awscrt-0.13.8-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:42563b5f395ac3feb960ac0034b99bf917fe54d305868549c92f3028a9dea2b6"},
+ {file = "awscrt-0.13.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:af7e50ffe858d0093e97b2c80c587f4c55a3e33a9a738f5bc404ad1ee390eaac"},
+ {file = "awscrt-0.13.8-cp36-cp36m-win32.whl", hash = "sha256:56270397a7464c81dba21a072dd2f89dc6425af3f6b5fe83685d44c436af37cf"},
+ {file = "awscrt-0.13.8-cp36-cp36m-win_amd64.whl", hash = "sha256:9fed5b0d70b4485dfca0f9db72ec25a8396eff090feec9269c63772f4406279e"},
+ {file = "awscrt-0.13.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4f5c51c3985c07ea4861574084886508056f14863c45c096154f73e5ea594d4c"},
+ {file = "awscrt-0.13.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70a5a684e95df54e948c762452183cd13011c5ffc5ce8fb0d4e65836bfd7319a"},
+ {file = "awscrt-0.13.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e494d2781ef5263665ffc665a5cbab0a7ed22f7c02a2ce92982791b0958666b"},
+ {file = "awscrt-0.13.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e70c760b6a82b3d9b1166540592600f85ed5da7a7fc9f16913b16dae6932ef79"},
+ {file = "awscrt-0.13.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:79b7f559d49aa37d7dbfd555d27d7bd3ff5e5d402790ced53678e84a55c5682d"},
+ {file = "awscrt-0.13.8-cp37-cp37m-win32.whl", hash = "sha256:24c707ea9f117d4dc1433e91901c2ff3a650265f89b0980f3ab1c8eb3d50e746"},
+ {file = "awscrt-0.13.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9e686907c3b0a5e6dcd7be670a6ef097dae3b5f516778dceb9c8057e22b3d37d"},
+ {file = "awscrt-0.13.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9866100a5095dcd092a263ec37e1f5c3718675891ee81c358b4480d69c4bbc2e"},
+ {file = "awscrt-0.13.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eaecd6f0a3af351670aaf40f7fd3b2268cccd8d75dc8162f9314d4a910d34e88"},
+ {file = "awscrt-0.13.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86e02c1d00af52780cecb0b2647fca9c88991b1da5faf45f7a6e5e8f0f7e25ec"},
+ {file = "awscrt-0.13.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:78eae0bebfef58dc00000394f26da90a7e7d1027bc75f6415c70f9d61af87361"},
+ {file = "awscrt-0.13.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:94da5b0bfd6fa3baa6360ce36cd0a7d1879d524d65f9a98c5b685d31a2174be9"},
+ {file = "awscrt-0.13.8-cp38-cp38-win32.whl", hash = "sha256:0f660ff9fef359d1f19c6737ddf5054b72a632e5c93b6975e34d3bbc94360489"},
+ {file = "awscrt-0.13.8-cp38-cp38-win_amd64.whl", hash = "sha256:547c3b72c779e5f3261a8c0832bca4f750c1b8844d7b157c906233bd960978d8"},
+ {file = "awscrt-0.13.8-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0606018249976b8a0d8333744a9b8220541808da6f482ad1baccde94e3fb8688"},
+ {file = "awscrt-0.13.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c540aeabdde42ecf05ac499bcde09695636ab8b97ec3e43b45a709247587d476"},
+ {file = "awscrt-0.13.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea073fa3ef3efb7700c536f8673f7952cb58deeeef627b0b322a19cfc517a456"},
+ {file = "awscrt-0.13.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0fd114b358a363e744ec559f5035b5bf1bbe70159682d1ff454bfa145cfe9fdf"},
+ {file = "awscrt-0.13.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:5c3de367f3fc5bbc1afc1e04b3dc352db2b89e655509574e92ab6ced0f7ee386"},
+ {file = "awscrt-0.13.8-cp39-cp39-win32.whl", hash = "sha256:339bf8ac195950cbdef2c36376fba23667c348085e7a86234d25be107775e156"},
+ {file = "awscrt-0.13.8-cp39-cp39-win_amd64.whl", hash = "sha256:4d350562743d1fdf3223b442c3b8e45e438d003d640ce40f0ffe7be2ea314ef2"},
+ {file = "awscrt-0.13.8.tar.gz", hash = "sha256:9f5724007537d4ba3f09d84ac5ae2c396b8d2e0e7ae410745ca726a59a2f08cb"},
+]
+backcall = [
+ {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"},
+ {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"},
+]
+beautifulsoup4 = [
+ {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"},
+ {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"},
+]
+black = [
+ {file = "black-22.6.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f586c26118bc6e714ec58c09df0157fe2d9ee195c764f630eb0d8e7ccce72e69"},
+ {file = "black-22.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b270a168d69edb8b7ed32c193ef10fd27844e5c60852039599f9184460ce0807"},
+ {file = "black-22.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6797f58943fceb1c461fb572edbe828d811e719c24e03375fd25170ada53825e"},
+ {file = "black-22.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c85928b9d5f83b23cee7d0efcb310172412fbf7cb9d9ce963bd67fd141781def"},
+ {file = "black-22.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:f6fe02afde060bbeef044af7996f335fbe90b039ccf3f5eb8f16df8b20f77666"},
+ {file = "black-22.6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cfaf3895a9634e882bf9d2363fed5af8888802d670f58b279b0bece00e9a872d"},
+ {file = "black-22.6.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94783f636bca89f11eb5d50437e8e17fbc6a929a628d82304c80fa9cd945f256"},
+ {file = "black-22.6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:2ea29072e954a4d55a2ff58971b83365eba5d3d357352a07a7a4df0d95f51c78"},
+ {file = "black-22.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e439798f819d49ba1c0bd9664427a05aab79bfba777a6db94fd4e56fae0cb849"},
+ {file = "black-22.6.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:187d96c5e713f441a5829e77120c269b6514418f4513a390b0499b0987f2ff1c"},
+ {file = "black-22.6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:074458dc2f6e0d3dab7928d4417bb6957bb834434516f21514138437accdbe90"},
+ {file = "black-22.6.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a218d7e5856f91d20f04e931b6f16d15356db1c846ee55f01bac297a705ca24f"},
+ {file = "black-22.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:568ac3c465b1c8b34b61cd7a4e349e93f91abf0f9371eda1cf87194663ab684e"},
+ {file = "black-22.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6c1734ab264b8f7929cef8ae5f900b85d579e6cbfde09d7387da8f04771b51c6"},
+ {file = "black-22.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a3ac16efe9ec7d7381ddebcc022119794872abce99475345c5a61aa18c45ad"},
+ {file = "black-22.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:b9fd45787ba8aa3f5e0a0a98920c1012c884622c6c920dbe98dbd05bc7c70fbf"},
+ {file = "black-22.6.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7ba9be198ecca5031cd78745780d65a3f75a34b2ff9be5837045dce55db83d1c"},
+ {file = "black-22.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a3db5b6409b96d9bd543323b23ef32a1a2b06416d525d27e0f67e74f1446c8f2"},
+ {file = "black-22.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:560558527e52ce8afba936fcce93a7411ab40c7d5fe8c2463e279e843c0328ee"},
+ {file = "black-22.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b154e6bbde1e79ea3260c4b40c0b7b3109ffcdf7bc4ebf8859169a6af72cd70b"},
+ {file = "black-22.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:4af5bc0e1f96be5ae9bd7aaec219c901a94d6caa2484c21983d043371c733fc4"},
+ {file = "black-22.6.0-py3-none-any.whl", hash = "sha256:ac609cf8ef5e7115ddd07d85d988d074ed00e10fbc3445aee393e70164a2219c"},
+ {file = "black-22.6.0.tar.gz", hash = "sha256:6c6d39e28aed379aec40da1c65434c77d75e65bb59a1e1c283de545fb4e7c6c9"},
+]
+bleach = [
+ {file = "bleach-6.0.0-py3-none-any.whl", hash = "sha256:33c16e3353dbd13028ab4799a0f89a83f113405c766e9c122df8a06f5b85b3f4"},
+ {file = "bleach-6.0.0.tar.gz", hash = "sha256:1a1a85c1595e07d8db14c5f09f09e6433502c51c595970edc090551f0db99414"},
+]
+botocore = [
+ {file = "botocore-1.27.22-py3-none-any.whl", hash = "sha256:7145d9b7cae87999a9f074de700d02a1b3222ee7d1863aa631ff56c5fc868035"},
+ {file = "botocore-1.27.22.tar.gz", hash = "sha256:f57cb33446deef92e552b0be0e430d475c73cf64bc9e46cdb4783cdfe39cb6bb"},
+]
+braceexpand = [
+ {file = "braceexpand-0.1.7-py2.py3-none-any.whl", hash = "sha256:91332d53de7828103dcae5773fb43bc34950b0c8160e35e0f44c4427a3b85014"},
+ {file = "braceexpand-0.1.7.tar.gz", hash = "sha256:e6e539bd20eaea53547472ff94f4fb5c3d3bf9d0a89388c4b56663aba765f705"},
+]
+cachetools = [
+ {file = "cachetools-5.2.0-py3-none-any.whl", hash = "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db"},
+ {file = "cachetools-5.2.0.tar.gz", hash = "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757"},
+]
+certifi = [
+ {file = "certifi-2022.6.15-py3-none-any.whl", hash = "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412"},
+ {file = "certifi-2022.6.15.tar.gz", hash = "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d"},
+]
+cffi = [
+ {file = "cffi-1.15.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a66d3508133af6e8548451b25058d5812812ec3798c886bf38ed24a98216fab2"},
+ {file = "cffi-1.15.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:470c103ae716238bbe698d67ad020e1db9d9dba34fa5a899b5e21577e6d52ed2"},
+ {file = "cffi-1.15.1-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:9ad5db27f9cabae298d151c85cf2bad1d359a1b9c686a275df03385758e2f914"},
+ {file = "cffi-1.15.1-cp27-cp27m-win32.whl", hash = "sha256:b3bbeb01c2b273cca1e1e0c5df57f12dce9a4dd331b4fa1635b8bec26350bde3"},
+ {file = "cffi-1.15.1-cp27-cp27m-win_amd64.whl", hash = "sha256:e00b098126fd45523dd056d2efba6c5a63b71ffe9f2bbe1a4fe1716e1d0c331e"},
+ {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:d61f4695e6c866a23a21acab0509af1cdfd2c013cf256bbf5b6b5e2695827162"},
+ {file = "cffi-1.15.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:ed9cb427ba5504c1dc15ede7d516b84757c3e3d7868ccc85121d9310d27eed0b"},
+ {file = "cffi-1.15.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d39875251ca8f612b6f33e6b1195af86d1b3e60086068be9cc053aa4376e21"},
+ {file = "cffi-1.15.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:285d29981935eb726a4399badae8f0ffdff4f5050eaa6d0cfc3f64b857b77185"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3eb6971dcff08619f8d91607cfc726518b6fa2a9eba42856be181c6d0d9515fd"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21157295583fe8943475029ed5abdcf71eb3911894724e360acff1d61c1d54bc"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5635bd9cb9731e6d4a1132a498dd34f764034a8ce60cef4f5319c0541159392f"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2012c72d854c2d03e45d06ae57f40d78e5770d252f195b93f581acf3ba44496e"},
+ {file = "cffi-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd86c085fae2efd48ac91dd7ccffcfc0571387fe1193d33b6394db7ef31fe2a4"},
+ {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:fa6693661a4c91757f4412306191b6dc88c1703f780c8234035eac011922bc01"},
+ {file = "cffi-1.15.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59c0b02d0a6c384d453fece7566d1c7e6b7bae4fc5874ef2ef46d56776d61c9e"},
+ {file = "cffi-1.15.1-cp310-cp310-win32.whl", hash = "sha256:cba9d6b9a7d64d4bd46167096fc9d2f835e25d7e4c121fb2ddfc6528fb0413b2"},
+ {file = "cffi-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:ce4bcc037df4fc5e3d184794f27bdaab018943698f4ca31630bc7f84a7b69c6d"},
+ {file = "cffi-1.15.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d08afd128ddaa624a48cf2b859afef385b720bb4b43df214f85616922e6a5ac"},
+ {file = "cffi-1.15.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3799aecf2e17cf585d977b780ce79ff0dc9b78d799fc694221ce814c2c19db83"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a591fe9e525846e4d154205572a029f653ada1a78b93697f3b5a8f1f2bc055b9"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3548db281cd7d2561c9ad9984681c95f7b0e38881201e157833a2342c30d5e8c"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91fc98adde3d7881af9b59ed0294046f3806221863722ba7d8d120c575314325"},
+ {file = "cffi-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94411f22c3985acaec6f83c6df553f2dbe17b698cc7f8ae751ff2237d96b9e3c"},
+ {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:03425bdae262c76aad70202debd780501fabeaca237cdfddc008987c0e0f59ef"},
+ {file = "cffi-1.15.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:cc4d65aeeaa04136a12677d3dd0b1c0c94dc43abac5860ab33cceb42b801c1e8"},
+ {file = "cffi-1.15.1-cp311-cp311-win32.whl", hash = "sha256:a0f100c8912c114ff53e1202d0078b425bee3649ae34d7b070e9697f93c5d52d"},
+ {file = "cffi-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:04ed324bda3cda42b9b695d51bb7d54b680b9719cfab04227cdd1e04e5de3104"},
+ {file = "cffi-1.15.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50a74364d85fd319352182ef59c5c790484a336f6db772c1a9231f1c3ed0cbd7"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e263d77ee3dd201c3a142934a086a4450861778baaeeb45db4591ef65550b0a6"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cec7d9412a9102bdc577382c3929b337320c4c4c4849f2c5cdd14d7368c5562d"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4289fc34b2f5316fbb762d75362931e351941fa95fa18789191b33fc4cf9504a"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:173379135477dc8cac4bc58f45db08ab45d228b3363adb7af79436135d028405"},
+ {file = "cffi-1.15.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6975a3fac6bc83c4a65c9f9fcab9e47019a11d3d2cf7f3c0d03431bf145a941e"},
+ {file = "cffi-1.15.1-cp36-cp36m-win32.whl", hash = "sha256:2470043b93ff09bf8fb1d46d1cb756ce6132c54826661a32d4e4d132e1977adf"},
+ {file = "cffi-1.15.1-cp36-cp36m-win_amd64.whl", hash = "sha256:30d78fbc8ebf9c92c9b7823ee18eb92f2e6ef79b45ac84db507f52fbe3ec4497"},
+ {file = "cffi-1.15.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:198caafb44239b60e252492445da556afafc7d1e3ab7a1fb3f0584ef6d742375"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5ef34d190326c3b1f822a5b7a45f6c4535e2f47ed06fec77d3d799c450b2651e"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8102eaf27e1e448db915d08afa8b41d6c7ca7a04b7d73af6514df10a3e74bd82"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5df2768244d19ab7f60546d0c7c63ce1581f7af8b5de3eb3004b9b6fc8a9f84b"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8c4917bd7ad33e8eb21e9a5bbba979b49d9a97acb3a803092cbc1133e20343c"},
+ {file = "cffi-1.15.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2642fe3142e4cc4af0799748233ad6da94c62a8bec3a6648bf8ee68b1c7426"},
+ {file = "cffi-1.15.1-cp37-cp37m-win32.whl", hash = "sha256:e229a521186c75c8ad9490854fd8bbdd9a0c9aa3a524326b55be83b54d4e0ad9"},
+ {file = "cffi-1.15.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a0b71b1b8fbf2b96e41c4d990244165e2c9be83d54962a9a1d118fd8657d2045"},
+ {file = "cffi-1.15.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:320dab6e7cb2eacdf0e658569d2575c4dad258c0fcc794f46215e1e39f90f2c3"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e74c6b51a9ed6589199c787bf5f9875612ca4a8a0785fb2d4a84429badaf22a"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5c84c68147988265e60416b57fc83425a78058853509c1b0629c180094904a5"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b926aa83d1edb5aa5b427b4053dc420ec295a08e40911296b9eb1b6170f6cca"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87c450779d0914f2861b8526e035c5e6da0a3199d8f1add1a665e1cbc6fc6d02"},
+ {file = "cffi-1.15.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2c9f67e9821cad2e5f480bc8d83b8742896f1242dba247911072d4fa94c192"},
+ {file = "cffi-1.15.1-cp38-cp38-win32.whl", hash = "sha256:8b7ee99e510d7b66cdb6c593f21c043c248537a32e0bedf02e01e9553a172314"},
+ {file = "cffi-1.15.1-cp38-cp38-win_amd64.whl", hash = "sha256:00a9ed42e88df81ffae7a8ab6d9356b371399b91dbdf0c3cb1e84c03a13aceb5"},
+ {file = "cffi-1.15.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:54a2db7b78338edd780e7ef7f9f6c442500fb0d41a5a4ea24fff1c929d5af585"},
+ {file = "cffi-1.15.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fcd131dd944808b5bdb38e6f5b53013c5aa4f334c5cad0c72742f6eba4b73db0"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7473e861101c9e72452f9bf8acb984947aa1661a7704553a9f6e4baa5ba64415"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c9a799e985904922a4d207a94eae35c78ebae90e128f0c4e521ce339396be9d"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3bcde07039e586f91b45c88f8583ea7cf7a0770df3a1649627bf598332cb6984"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33ab79603146aace82c2427da5ca6e58f2b3f2fb5da893ceac0c42218a40be35"},
+ {file = "cffi-1.15.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d598b938678ebf3c67377cdd45e09d431369c3b1a5b331058c338e201f12b27"},
+ {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db0fbb9c62743ce59a9ff687eb5f4afbe77e5e8403d6697f7446e5f609976f76"},
+ {file = "cffi-1.15.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:98d85c6a2bef81588d9227dde12db8a7f47f639f4a17c9ae08e773aa9c697bf3"},
+ {file = "cffi-1.15.1-cp39-cp39-win32.whl", hash = "sha256:40f4774f5a9d4f5e344f31a32b5096977b5d48560c5592e2f3d2c4374bd543ee"},
+ {file = "cffi-1.15.1-cp39-cp39-win_amd64.whl", hash = "sha256:70df4e3b545a17496c9b3f41f5115e69a4f2e77e94e1d2a8e1070bc0c38c8a3c"},
+ {file = "cffi-1.15.1.tar.gz", hash = "sha256:d400bfb9a37b1351253cb402671cea7e89bdecc294e8016a707f6d1d8ac934f9"},
+]
+cfgv = [
+ {file = "cfgv-3.3.1-py2.py3-none-any.whl", hash = "sha256:c6a0883f3917a037485059700b9e75da2464e6c27051014ad85ba6aaa5884426"},
+ {file = "cfgv-3.3.1.tar.gz", hash = "sha256:f5a830efb9ce7a445376bb66ec94c638a9787422f96264c98edc6bdeed8ab736"},
+]
+charset-normalizer = [
+ {file = "charset-normalizer-2.1.0.tar.gz", hash = "sha256:575e708016ff3a5e3681541cb9d79312c416835686d054a23accb873b254f413"},
+ {file = "charset_normalizer-2.1.0-py3-none-any.whl", hash = "sha256:5189b6f22b01957427f35b6a08d9a0bc45b46d3788ef5a92e978433c7a35f8a5"},
+]
+click = [
+ {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"},
+ {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
+]
+cloudpickle = [
+ {file = "cloudpickle-2.1.0-py3-none-any.whl", hash = "sha256:b5c434f75c34624eedad3a14f2be5ac3b5384774d5b0e3caf905c21479e6c4b1"},
+ {file = "cloudpickle-2.1.0.tar.gz", hash = "sha256:bb233e876a58491d9590a676f93c7a5473a08f747d5ab9df7f9ce564b3e7938e"},
+]
+colorama = [
+ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
+ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
+]
+cycler = [
+ {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
+ {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
+]
+databricks-cli = [
+ {file = "databricks-cli-0.17.3.tar.gz", hash = "sha256:2f00f3e70e859809f0595885ec76fc73ba60ad0cccd69564f7df5d95b6c90066"},
+ {file = "databricks_cli-0.17.3-py2-none-any.whl", hash = "sha256:f090c2e4f99c39d69a7f7228e6c7df8cb1cebd5fddad6292e0625daf29d4be01"},
+]
+debugpy = [
+ {file = "debugpy-1.6.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b3e7ac809b991006ad7f857f016fa92014445085711ef111fdc3f74f66144096"},
+ {file = "debugpy-1.6.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3876611d114a18aafef6383695dfc3f1217c98a9168c1aaf1a02b01ec7d8d1e"},
+ {file = "debugpy-1.6.7-cp310-cp310-win32.whl", hash = "sha256:33edb4afa85c098c24cc361d72ba7c21bb92f501104514d4ffec1fb36e09c01a"},
+ {file = "debugpy-1.6.7-cp310-cp310-win_amd64.whl", hash = "sha256:ed6d5413474e209ba50b1a75b2d9eecf64d41e6e4501977991cdc755dc83ab0f"},
+ {file = "debugpy-1.6.7-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:38ed626353e7c63f4b11efad659be04c23de2b0d15efff77b60e4740ea685d07"},
+ {file = "debugpy-1.6.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:279d64c408c60431c8ee832dfd9ace7c396984fd7341fa3116aee414e7dcd88d"},
+ {file = "debugpy-1.6.7-cp37-cp37m-win32.whl", hash = "sha256:dbe04e7568aa69361a5b4c47b4493d5680bfa3a911d1e105fbea1b1f23f3eb45"},
+ {file = "debugpy-1.6.7-cp37-cp37m-win_amd64.whl", hash = "sha256:f90a2d4ad9a035cee7331c06a4cf2245e38bd7c89554fe3b616d90ab8aab89cc"},
+ {file = "debugpy-1.6.7-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:5224eabbbeddcf1943d4e2821876f3e5d7d383f27390b82da5d9558fd4eb30a9"},
+ {file = "debugpy-1.6.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae1123dff5bfe548ba1683eb972329ba6d646c3a80e6b4c06cd1b1dd0205e9b"},
+ {file = "debugpy-1.6.7-cp38-cp38-win32.whl", hash = "sha256:9cd10cf338e0907fdcf9eac9087faa30f150ef5445af5a545d307055141dd7a4"},
+ {file = "debugpy-1.6.7-cp38-cp38-win_amd64.whl", hash = "sha256:aaf6da50377ff4056c8ed470da24632b42e4087bc826845daad7af211e00faad"},
+ {file = "debugpy-1.6.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:0679b7e1e3523bd7d7869447ec67b59728675aadfc038550a63a362b63029d2c"},
+ {file = "debugpy-1.6.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de86029696e1b3b4d0d49076b9eba606c226e33ae312a57a46dca14ff370894d"},
+ {file = "debugpy-1.6.7-cp39-cp39-win32.whl", hash = "sha256:d71b31117779d9a90b745720c0eab54ae1da76d5b38c8026c654f4a066b0130a"},
+ {file = "debugpy-1.6.7-cp39-cp39-win_amd64.whl", hash = "sha256:c0ff93ae90a03b06d85b2c529eca51ab15457868a377c4cc40a23ab0e4e552a3"},
+ {file = "debugpy-1.6.7-py2.py3-none-any.whl", hash = "sha256:53f7a456bc50706a0eaabecf2d3ce44c4d5010e46dfc65b6b81a518b42866267"},
+ {file = "debugpy-1.6.7.zip", hash = "sha256:c4c2f0810fa25323abfdfa36cbbbb24e5c3b1a42cb762782de64439c575d67f2"},
+]
+decorator = [
+ {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
+ {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
+]
+decord = [
+ {file = "decord-0.6.0-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:85ef90d2f872384657d7774cc486c237c5b12df62d4ac5cb5c8d6001fa611323"},
+ {file = "decord-0.6.0-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:9c20674964fb1490c677bd911d2023d2a09fec7a58a4bb0b7ddf1ccc269f107a"},
+ {file = "decord-0.6.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:a0eb1258beade34dceb29d97856a7764d179db1b5182899b61874f3418a1abc8"},
+ {file = "decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976"},
+ {file = "decord-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02665d7c4f1193a330205a791bc128f7e108eb6ae5b67144437a02f700943bad"},
+]
+defusedxml = [
+ {file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
+ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
+]
+distlib = [
+ {file = "distlib-0.3.4-py2.py3-none-any.whl", hash = "sha256:6564fe0a8f51e734df6333d08b8b94d4ea8ee6b99b5ed50613f731fd4089f34b"},
+ {file = "distlib-0.3.4.zip", hash = "sha256:e4b58818180336dc9c529bfb9a0b58728ffc09ad92027a3f30b7cd91e3458579"},
+]
+docker = [
+ {file = "docker-6.0.0-py3-none-any.whl", hash = "sha256:6e06ee8eca46cd88733df09b6b80c24a1a556bc5cb1e1ae54b2c239886d245cf"},
+ {file = "docker-6.0.0.tar.gz", hash = "sha256:19e330470af40167d293b0352578c1fa22d74b34d3edf5d4ff90ebc203bbb2f1"},
+]
+docutils = [
+ {file = "docutils-0.16-py2.py3-none-any.whl", hash = "sha256:0c5b78adfbf7762415433f5515cd5c9e762339e23369dbe8000d84a4bf4ab3af"},
+ {file = "docutils-0.16.tar.gz", hash = "sha256:c2de3a60e9e7d07be26b7f2b00ca0309c207e06c100f9cc2a94931fc75a478fc"},
+]
+einops = [
+ {file = "einops-0.6.0-py3-none-any.whl", hash = "sha256:c7b187a5dc725f079860ec2d330c1820448948622d826273345a8dd8d5f695bd"},
+ {file = "einops-0.6.0.tar.gz", hash = "sha256:6f6c78739316a2e3ccbce8052310497e69da092935e4173f2e76ec4e3a336a35"},
+]
+entrypoints = [
+ {file = "entrypoints-0.4-py3-none-any.whl", hash = "sha256:f174b5ff827504fd3cd97cc3f8649f3693f51538c7e4bdf3ef002c8429d42f9f"},
+ {file = "entrypoints-0.4.tar.gz", hash = "sha256:b706eddaa9218a19ebcd67b56818f05bb27589b1ca9e8d797b74affad4ccacd4"},
+]
+fastjsonschema = [
+ {file = "fastjsonschema-2.16.3-py3-none-any.whl", hash = "sha256:04fbecc94300436f628517b05741b7ea009506ce8f946d40996567c669318490"},
+ {file = "fastjsonschema-2.16.3.tar.gz", hash = "sha256:4a30d6315a68c253cfa8f963b9697246315aa3db89f98b97235e345dedfb0b8e"},
+]
+filelock = [
+ {file = "filelock-3.7.1-py3-none-any.whl", hash = "sha256:37def7b658813cda163b56fc564cdc75e86d338246458c4c28ae84cabefa2404"},
+ {file = "filelock-3.7.1.tar.gz", hash = "sha256:3a0fd85166ad9dbab54c9aec96737b744106dc5f15c0b09a6744a445299fcf04"},
+]
+flake8 = [
+ {file = "flake8-4.0.1-py2.py3-none-any.whl", hash = "sha256:479b1304f72536a55948cb40a32dce8bb0ffe3501e26eaf292c7e60eb5e0428d"},
+ {file = "flake8-4.0.1.tar.gz", hash = "sha256:806e034dda44114815e23c16ef92f95c91e4c71100ff52813adf7132a6ad870d"},
+]
+flake8-bugbear = [
+ {file = "flake8-bugbear-22.7.1.tar.gz", hash = "sha256:e450976a07e4f9d6c043d4f72b17ec1baf717fe37f7997009c8ae58064f88305"},
+ {file = "flake8_bugbear-22.7.1-py3-none-any.whl", hash = "sha256:db5d7a831ef4412a224b26c708967ff816818cabae415e76b8c58df156c4b8e5"},
+]
+flake8-docstrings = [
+ {file = "flake8-docstrings-1.6.0.tar.gz", hash = "sha256:9fe7c6a306064af8e62a055c2f61e9eb1da55f84bb39caef2b84ce53708ac34b"},
+ {file = "flake8_docstrings-1.6.0-py2.py3-none-any.whl", hash = "sha256:99cac583d6c7e32dd28bbfbef120a7c0d1b6dde4adb5a9fd441c4227a6534bde"},
+]
+flake8-isort = [
+ {file = "flake8-isort-4.1.1.tar.gz", hash = "sha256:d814304ab70e6e58859bc5c3e221e2e6e71c958e7005239202fee19c24f82717"},
+ {file = "flake8_isort-4.1.1-py3-none-any.whl", hash = "sha256:c4e8b6dcb7be9b71a02e6e5d4196cefcef0f3447be51e82730fb336fff164949"},
+]
+flake8-tidy-imports = [
+ {file = "flake8-tidy-imports-4.8.0.tar.gz", hash = "sha256:df44f9c841b5dfb3a7a1f0da8546b319d772c2a816a1afefcce43e167a593d83"},
+ {file = "flake8_tidy_imports-4.8.0-py3-none-any.whl", hash = "sha256:25bd9799358edefa0e010ce2c587b093c3aba942e96aeaa99b6d0500ae1bf09c"},
+]
+Flask = [
+ {file = "Flask-2.1.3-py3-none-any.whl", hash = "sha256:9013281a7402ad527f8fd56375164f3aa021ecfaff89bfe3825346c24f87e04c"},
+ {file = "Flask-2.1.3.tar.gz", hash = "sha256:15972e5017df0575c3d6c090ba168b6db90259e620ac8d7ea813a396bad5b6cb"},
+]
+fonttools = [
+ {file = "fonttools-4.33.3-py3-none-any.whl", hash = "sha256:f829c579a8678fa939a1d9e9894d01941db869de44390adb49ce67055a06cc2a"},
+ {file = "fonttools-4.33.3.zip", hash = "sha256:c0fdcfa8ceebd7c1b2021240bd46ef77aa8e7408cf10434be55df52384865f8e"},
+]
+frozenlist = [
+ {file = "frozenlist-1.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d2257aaba9660f78c7b1d8fea963b68f3feffb1a9d5d05a18401ca9eb3e8d0a3"},
+ {file = "frozenlist-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4a44ebbf601d7bac77976d429e9bdb5a4614f9f4027777f9e54fd765196e9d3b"},
+ {file = "frozenlist-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:45334234ec30fc4ea677f43171b18a27505bfb2dba9aca4398a62692c0ea8868"},
+ {file = "frozenlist-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47be22dc27ed933d55ee55845d34a3e4e9f6fee93039e7f8ebadb0c2f60d403f"},
+ {file = "frozenlist-1.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:03a7dd1bfce30216a3f51a84e6dd0e4a573d23ca50f0346634916ff105ba6e6b"},
+ {file = "frozenlist-1.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:691ddf6dc50480ce49f68441f1d16a4c3325887453837036e0fb94736eae1e58"},
+ {file = "frozenlist-1.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bde99812f237f79eaf3f04ebffd74f6718bbd216101b35ac7955c2d47c17da02"},
+ {file = "frozenlist-1.3.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a202458d1298ced3768f5a7d44301e7c86defac162ace0ab7434c2e961166e8"},
+ {file = "frozenlist-1.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b9e3e9e365991f8cc5f5edc1fd65b58b41d0514a6a7ad95ef5c7f34eb49b3d3e"},
+ {file = "frozenlist-1.3.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:04cb491c4b1c051734d41ea2552fde292f5f3a9c911363f74f39c23659c4af78"},
+ {file = "frozenlist-1.3.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:436496321dad302b8b27ca955364a439ed1f0999311c393dccb243e451ff66aa"},
+ {file = "frozenlist-1.3.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:754728d65f1acc61e0f4df784456106e35afb7bf39cfe37227ab00436fb38676"},
+ {file = "frozenlist-1.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6eb275c6385dd72594758cbe96c07cdb9bd6becf84235f4a594bdf21e3596c9d"},
+ {file = "frozenlist-1.3.0-cp310-cp310-win32.whl", hash = "sha256:e30b2f9683812eb30cf3f0a8e9f79f8d590a7999f731cf39f9105a7c4a39489d"},
+ {file = "frozenlist-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f7353ba3367473d1d616ee727945f439e027f0bb16ac1a750219a8344d1d5d3c"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:88aafd445a233dbbf8a65a62bc3249a0acd0d81ab18f6feb461cc5a938610d24"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4406cfabef8f07b3b3af0f50f70938ec06d9f0fc26cbdeaab431cbc3ca3caeaa"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8cf829bd2e2956066dd4de43fd8ec881d87842a06708c035b37ef632930505a2"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:603b9091bd70fae7be28bdb8aa5c9990f4241aa33abb673390a7f7329296695f"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:25af28b560e0c76fa41f550eacb389905633e7ac02d6eb3c09017fa1c8cdfde1"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94c7a8a9fc9383b52c410a2ec952521906d355d18fccc927fca52ab575ee8b93"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:65bc6e2fece04e2145ab6e3c47428d1bbc05aede61ae365b2c1bddd94906e478"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:3f7c935c7b58b0d78c0beea0c7358e165f95f1fd8a7e98baa40d22a05b4a8141"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd89acd1b8bb4f31b47072615d72e7f53a948d302b7c1d1455e42622de180eae"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:6983a31698490825171be44ffbafeaa930ddf590d3f051e397143a5045513b01"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:adac9700675cf99e3615eb6a0eb5e9f5a4143c7d42c05cea2e7f71c27a3d0846"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-win32.whl", hash = "sha256:0c36e78b9509e97042ef869c0e1e6ef6429e55817c12d78245eb915e1cca7468"},
+ {file = "frozenlist-1.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:57f4d3f03a18facacb2a6bcd21bccd011e3b75d463dc49f838fd699d074fabd1"},
+ {file = "frozenlist-1.3.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8c905a5186d77111f02144fab5b849ab524f1e876a1e75205cd1386a9be4b00a"},
+ {file = "frozenlist-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b5009062d78a8c6890d50b4e53b0ddda31841b3935c1937e2ed8c1bda1c7fb9d"},
+ {file = "frozenlist-1.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2fdc3cd845e5a1f71a0c3518528bfdbfe2efaf9886d6f49eacc5ee4fd9a10953"},
+ {file = "frozenlist-1.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92e650bd09b5dda929523b9f8e7f99b24deac61240ecc1a32aeba487afcd970f"},
+ {file = "frozenlist-1.3.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:40dff8962b8eba91fd3848d857203f0bd704b5f1fa2b3fc9af64901a190bba08"},
+ {file = "frozenlist-1.3.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:768efd082074bb203c934e83a61654ed4931ef02412c2fbdecea0cff7ecd0274"},
+ {file = "frozenlist-1.3.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:006d3595e7d4108a12025ddf415ae0f6c9e736e726a5db0183326fd191b14c5e"},
+ {file = "frozenlist-1.3.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:871d42623ae15eb0b0e9df65baeee6976b2e161d0ba93155411d58ff27483ad8"},
+ {file = "frozenlist-1.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aff388be97ef2677ae185e72dc500d19ecaf31b698986800d3fc4f399a5e30a5"},
+ {file = "frozenlist-1.3.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9f892d6a94ec5c7b785e548e42722e6f3a52f5f32a8461e82ac3e67a3bd073f1"},
+ {file = "frozenlist-1.3.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:e982878792c971cbd60ee510c4ee5bf089a8246226dea1f2138aa0bb67aff148"},
+ {file = "frozenlist-1.3.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:c6c321dd013e8fc20735b92cb4892c115f5cdb82c817b1e5b07f6b95d952b2f0"},
+ {file = "frozenlist-1.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:30530930410855c451bea83f7b272fb1c495ed9d5cc72895ac29e91279401db3"},
+ {file = "frozenlist-1.3.0-cp38-cp38-win32.whl", hash = "sha256:40ec383bc194accba825fbb7d0ef3dda5736ceab2375462f1d8672d9f6b68d07"},
+ {file = "frozenlist-1.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:f20baa05eaa2bcd5404c445ec51aed1c268d62600362dc6cfe04fae34a424bd9"},
+ {file = "frozenlist-1.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0437fe763fb5d4adad1756050cbf855bbb2bf0d9385c7bb13d7a10b0dd550486"},
+ {file = "frozenlist-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b684c68077b84522b5c7eafc1dc735bfa5b341fb011d5552ebe0968e22ed641c"},
+ {file = "frozenlist-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93641a51f89473837333b2f8100f3f89795295b858cd4c7d4a1f18e299dc0a4f"},
+ {file = "frozenlist-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6d32ff213aef0fd0bcf803bffe15cfa2d4fde237d1d4838e62aec242a8362fa"},
+ {file = "frozenlist-1.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31977f84828b5bb856ca1eb07bf7e3a34f33a5cddce981d880240ba06639b94d"},
+ {file = "frozenlist-1.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c62964192a1c0c30b49f403495911298810bada64e4f03249ca35a33ca0417a"},
+ {file = "frozenlist-1.3.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4eda49bea3602812518765810af732229b4291d2695ed24a0a20e098c45a707b"},
+ {file = "frozenlist-1.3.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acb267b09a509c1df5a4ca04140da96016f40d2ed183cdc356d237286c971b51"},
+ {file = "frozenlist-1.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e1e26ac0a253a2907d654a37e390904426d5ae5483150ce3adedb35c8c06614a"},
+ {file = "frozenlist-1.3.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f96293d6f982c58ebebb428c50163d010c2f05de0cde99fd681bfdc18d4b2dc2"},
+ {file = "frozenlist-1.3.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e84cb61b0ac40a0c3e0e8b79c575161c5300d1d89e13c0e02f76193982f066ed"},
+ {file = "frozenlist-1.3.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:ff9310f05b9d9c5c4dd472983dc956901ee6cb2c3ec1ab116ecdde25f3ce4951"},
+ {file = "frozenlist-1.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d26b650b71fdc88065b7a21f8ace70175bcf3b5bdba5ea22df4bfd893e795a3b"},
+ {file = "frozenlist-1.3.0-cp39-cp39-win32.whl", hash = "sha256:01a73627448b1f2145bddb6e6c2259988bb8aee0fb361776ff8604b99616cd08"},
+ {file = "frozenlist-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:772965f773757a6026dea111a15e6e2678fbd6216180f82a48a40b27de1ee2ab"},
+ {file = "frozenlist-1.3.0.tar.gz", hash = "sha256:ce6f2ba0edb7b0c1d8976565298ad2deba6f8064d2bebb6ffce2ca896eb35b0b"},
+]
+fsspec = [
+ {file = "fsspec-2022.7.1-py3-none-any.whl", hash = "sha256:36c5a8e7c4fc20cf32ef6934ac0a122accc8a593ddc8478d30c3ca4dbbd95500"},
+ {file = "fsspec-2022.7.1.tar.gz", hash = "sha256:7f9fb19d811b027b97c4636c6073eb53bc4cbee2d3c4b33fa88b9f26906fd7d7"},
+]
+ftfy = [
+ {file = "ftfy-6.1.1-py3-none-any.whl", hash = "sha256:0ffd33fce16b54cccaec78d6ec73d95ad370e5df5a25255c8966a6147bd667ca"},
+ {file = "ftfy-6.1.1.tar.gz", hash = "sha256:bfc2019f84fcd851419152320a6375604a0f1459c281b5b199b2cd0d2e727f8f"},
+]
+gitdb = [
+ {file = "gitdb-4.0.9-py3-none-any.whl", hash = "sha256:8033ad4e853066ba6ca92050b9df2f89301b8fc8bf7e9324d412a63f8bf1a8fd"},
+ {file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"},
+]
+GitPython = [
+ {file = "GitPython-3.1.29-py3-none-any.whl", hash = "sha256:41eea0deec2deea139b459ac03656f0dd28fc4a3387240ec1d3c259a2c47850f"},
+ {file = "GitPython-3.1.29.tar.gz", hash = "sha256:cc36bfc4a3f913e66805a28e84703e419d9c264c1077e537b54f0e1af85dbefd"},
+]
+google-auth = [
+ {file = "google-auth-2.9.0.tar.gz", hash = "sha256:3b2f9d2f436cc7c3b363d0ac66470f42fede249c3bafcc504e9f0bcbe983cff0"},
+ {file = "google_auth-2.9.0-py2.py3-none-any.whl", hash = "sha256:75b3977e7e22784607e074800048f44d6a56df589fb2abe58a11d4d20c97c314"},
+]
+google-auth-oauthlib = [
+ {file = "google-auth-oauthlib-0.4.6.tar.gz", hash = "sha256:a90a072f6993f2c327067bf65270046384cda5a8ecb20b94ea9a687f1f233a7a"},
+ {file = "google_auth_oauthlib-0.4.6-py2.py3-none-any.whl", hash = "sha256:3f2a6e802eebbb6fb736a370fbf3b055edcb6b52878bf2f26330b5e041316c73"},
+]
+greenlet = [
+ {file = "greenlet-1.1.3.post0-cp27-cp27m-macosx_10_14_x86_64.whl", hash = "sha256:949c9061b8c6d3e6e439466a9be1e787208dec6246f4ec5fffe9677b4c19fcc3"},
+ {file = "greenlet-1.1.3.post0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:d7815e1519a8361c5ea2a7a5864945906f8e386fa1bc26797b4d443ab11a4589"},
+ {file = "greenlet-1.1.3.post0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:9649891ab4153f217f319914455ccf0b86986b55fc0573ce803eb998ad7d6854"},
+ {file = "greenlet-1.1.3.post0-cp27-cp27m-win32.whl", hash = "sha256:11fc7692d95cc7a6a8447bb160d98671ab291e0a8ea90572d582d57361360f05"},
+ {file = "greenlet-1.1.3.post0-cp27-cp27m-win_amd64.whl", hash = "sha256:05ae7383f968bba4211b1fbfc90158f8e3da86804878442b4fb6c16ccbcaa519"},
+ {file = "greenlet-1.1.3.post0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:ccbe7129a282ec5797df0451ca1802f11578be018a32979131065565da89b392"},
+ {file = "greenlet-1.1.3.post0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:4a8b58232f5b72973350c2b917ea3df0bebd07c3c82a0a0e34775fc2c1f857e9"},
+ {file = "greenlet-1.1.3.post0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:f6661b58412879a2aa099abb26d3c93e91dedaba55a6394d1fb1512a77e85de9"},
+ {file = "greenlet-1.1.3.post0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c6e942ca9835c0b97814d14f78da453241837419e0d26f7403058e8db3e38f8"},
+ {file = "greenlet-1.1.3.post0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a812df7282a8fc717eafd487fccc5ba40ea83bb5b13eb3c90c446d88dbdfd2be"},
+ {file = "greenlet-1.1.3.post0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83a7a6560df073ec9de2b7cb685b199dfd12519bc0020c62db9d1bb522f989fa"},
+ {file = "greenlet-1.1.3.post0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:17a69967561269b691747e7f436d75a4def47e5efcbc3c573180fc828e176d80"},
+ {file = "greenlet-1.1.3.post0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:60839ab4ea7de6139a3be35b77e22e0398c270020050458b3d25db4c7c394df5"},
+ {file = "greenlet-1.1.3.post0-cp310-cp310-win_amd64.whl", hash = "sha256:8926a78192b8b73c936f3e87929931455a6a6c6c385448a07b9f7d1072c19ff3"},
+ {file = "greenlet-1.1.3.post0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:c6f90234e4438062d6d09f7d667f79edcc7c5e354ba3a145ff98176f974b8132"},
+ {file = "greenlet-1.1.3.post0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:814f26b864ed2230d3a7efe0336f5766ad012f94aad6ba43a7c54ca88dd77cba"},
+ {file = "greenlet-1.1.3.post0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fda1139d87ce5f7bd80e80e54f9f2c6fe2f47983f1a6f128c47bf310197deb6"},
+ {file = "greenlet-1.1.3.post0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0643250dd0756f4960633f5359884f609a234d4066686754e834073d84e9b51"},
+ {file = "greenlet-1.1.3.post0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:cb863057bed786f6622982fb8b2c122c68e6e9eddccaa9fa98fd937e45ee6c4f"},
+ {file = "greenlet-1.1.3.post0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8c0581077cf2734569f3e500fab09c0ff6a2ab99b1afcacbad09b3c2843ae743"},
+ {file = "greenlet-1.1.3.post0-cp35-cp35m-macosx_10_14_x86_64.whl", hash = "sha256:695d0d8b5ae42c800f1763c9fce9d7b94ae3b878919379150ee5ba458a460d57"},
+ {file = "greenlet-1.1.3.post0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:5662492df0588a51d5690f6578f3bbbd803e7f8d99a99f3bf6128a401be9c269"},
+ {file = "greenlet-1.1.3.post0-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:bffba15cff4802ff493d6edcf20d7f94ab1c2aee7cfc1e1c7627c05f1102eee8"},
+ {file = "greenlet-1.1.3.post0-cp35-cp35m-win32.whl", hash = "sha256:7afa706510ab079fd6d039cc6e369d4535a48e202d042c32e2097f030a16450f"},
+ {file = "greenlet-1.1.3.post0-cp35-cp35m-win_amd64.whl", hash = "sha256:3a24f3213579dc8459e485e333330a921f579543a5214dbc935bc0763474ece3"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:64e10f303ea354500c927da5b59c3802196a07468332d292aef9ddaca08d03dd"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:eb6ac495dccb1520667cfea50d89e26f9ffb49fa28496dea2b95720d8b45eb54"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:88720794390002b0c8fa29e9602b395093a9a766b229a847e8d88349e418b28a"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39464518a2abe9c505a727af7c0b4efff2cf242aa168be5f0daa47649f4d7ca8"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0914f02fcaa8f84f13b2df4a81645d9e82de21ed95633765dd5cc4d3af9d7403"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96656c5f7c95fc02c36d4f6ef32f4e94bb0b6b36e6a002c21c39785a4eec5f5d"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:4f74aa0092602da2069df0bc6553919a15169d77bcdab52a21f8c5242898f519"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:3aeac044c324c1a4027dca0cde550bd83a0c0fbff7ef2c98df9e718a5086c194"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-win32.whl", hash = "sha256:fe7c51f8a2ab616cb34bc33d810c887e89117771028e1e3d3b77ca25ddeace04"},
+ {file = "greenlet-1.1.3.post0-cp36-cp36m-win_amd64.whl", hash = "sha256:70048d7b2c07c5eadf8393e6398595591df5f59a2f26abc2f81abca09610492f"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:66aa4e9a726b70bcbfcc446b7ba89c8cec40f405e51422c39f42dfa206a96a05"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:025b8de2273d2809f027d347aa2541651d2e15d593bbce0d5f502ca438c54136"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:82a38d7d2077128a017094aff334e67e26194f46bd709f9dcdacbf3835d47ef5"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7d20c3267385236b4ce54575cc8e9f43e7673fc761b069c820097092e318e3b"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c8ece5d1a99a2adcb38f69af2f07d96fb615415d32820108cd340361f590d128"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2794eef1b04b5ba8948c72cc606aab62ac4b0c538b14806d9c0d88afd0576d6b"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:a8d24eb5cb67996fb84633fdc96dbc04f2d8b12bfcb20ab3222d6be271616b67"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0120a879aa2b1ac5118bce959ea2492ba18783f65ea15821680a256dfad04754"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-win32.whl", hash = "sha256:bef49c07fcb411c942da6ee7d7ea37430f830c482bf6e4b72d92fd506dd3a427"},
+ {file = "greenlet-1.1.3.post0-cp37-cp37m-win_amd64.whl", hash = "sha256:62723e7eb85fa52e536e516ee2ac91433c7bb60d51099293671815ff49ed1c21"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:d25cdedd72aa2271b984af54294e9527306966ec18963fd032cc851a725ddc1b"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:924df1e7e5db27d19b1359dc7d052a917529c95ba5b8b62f4af611176da7c8ad"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ec615d2912b9ad807afd3be80bf32711c0ff9c2b00aa004a45fd5d5dde7853d9"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0971d37ae0eaf42344e8610d340aa0ad3d06cd2eee381891a10fe771879791f9"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:325f272eb997916b4a3fc1fea7313a8adb760934c2140ce13a2117e1b0a8095d"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d75afcbb214d429dacdf75e03a1d6d6c5bd1fa9c35e360df8ea5b6270fb2211c"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5c2d21c2b768d8c86ad935e404cc78c30d53dea009609c3ef3a9d49970c864b5"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:467b73ce5dcd89e381292fb4314aede9b12906c18fab903f995b86034d96d5c8"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-win32.whl", hash = "sha256:8149a6865b14c33be7ae760bcdb73548bb01e8e47ae15e013bf7ef9290ca309a"},
+ {file = "greenlet-1.1.3.post0-cp38-cp38-win_amd64.whl", hash = "sha256:104f29dd822be678ef6b16bf0035dcd43206a8a48668a6cae4d2fe9c7a7abdeb"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:c8c9301e3274276d3d20ab6335aa7c5d9e5da2009cccb01127bddb5c951f8870"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:8415239c68b2ec9de10a5adf1130ee9cb0ebd3e19573c55ba160ff0ca809e012"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:3c22998bfef3fcc1b15694818fc9b1b87c6cc8398198b96b6d355a7bcb8c934e"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0aa1845944e62f358d63fcc911ad3b415f585612946b8edc824825929b40e59e"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:890f633dc8cb307761ec566bc0b4e350a93ddd77dc172839be122be12bae3e10"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cf37343e43404699d58808e51f347f57efd3010cc7cee134cdb9141bd1ad9ea"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5edf75e7fcfa9725064ae0d8407c849456553a181ebefedb7606bac19aa1478b"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a954002064ee919b444b19c1185e8cce307a1f20600f47d6f4b6d336972c809"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-win32.whl", hash = "sha256:2ccdc818cc106cc238ff7eba0d71b9c77be868fdca31d6c3b1347a54c9b187b2"},
+ {file = "greenlet-1.1.3.post0-cp39-cp39-win_amd64.whl", hash = "sha256:91a84faf718e6f8b888ca63d0b2d6d185c8e2a198d2a7322d75c303e7097c8b7"},
+ {file = "greenlet-1.1.3.post0.tar.gz", hash = "sha256:f5e09dc5c6e1796969fd4b775ea1417d70e49a5df29aaa8e5d10675d9e11872c"},
+]
+grpcio = [
+ {file = "grpcio-1.47.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:544da3458d1d249bb8aed5504adf3e194a931e212017934bf7bfa774dad37fb3"},
+ {file = "grpcio-1.47.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:b88bec3f94a16411a1e0336eb69f335f58229e45d4082b12d8e554cedea97586"},
+ {file = "grpcio-1.47.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:06c0739dff9e723bca28ec22301f3711d85c2e652d1c8ae938aa0f7ad632ef9a"},
+ {file = "grpcio-1.47.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4508e8abd67ebcccd0fbde6e2b1917ba5d153f3f20c1de385abd8722545e05f"},
+ {file = "grpcio-1.47.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9723784cf264697024778dcf4b7542c851fe14b14681d6268fb984a53f76df1"},
+ {file = "grpcio-1.47.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:1bb9afa85e797a646bfcd785309e869e80a375c959b11a17c9680abebacc0cb0"},
+ {file = "grpcio-1.47.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d9ad7122f60157454f74a850d1337ba135146cef6fb7956d78c7194d52db0fe"},
+ {file = "grpcio-1.47.0-cp310-cp310-win32.whl", hash = "sha256:0425b5577be202d0a4024536bbccb1b052c47e0766096e6c3a5789ddfd5f400d"},
+ {file = "grpcio-1.47.0-cp310-cp310-win_amd64.whl", hash = "sha256:d0d481ff55ea6cc49dab2c8276597bd4f1a84a8745fedb4bc23e12e9fb9d0e45"},
+ {file = "grpcio-1.47.0-cp36-cp36m-linux_armv7l.whl", hash = "sha256:5f57b9b61c22537623a5577bf5f2f970dc4e50fac5391090114c6eb3ab5a129f"},
+ {file = "grpcio-1.47.0-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:14d2bc74218986e5edf5527e870b0969d63601911994ebf0dce96288548cf0ef"},
+ {file = "grpcio-1.47.0-cp36-cp36m-manylinux_2_17_aarch64.whl", hash = "sha256:c79996ae64dc4d8730782dff0d1daacc8ce7d4c2ba9cef83b6f469f73c0655ce"},
+ {file = "grpcio-1.47.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a24b50810aae90c74bbd901c3f175b9645802d2fbf03eadaf418ddee4c26668"},
+ {file = "grpcio-1.47.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55782a31ec539f15b34ee56f19131fe1430f38a4be022eb30c85e0b0dcf57f11"},
+ {file = "grpcio-1.47.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:35dfd981b03a3ec842671d1694fe437ee9f7b9e6a02792157a2793b0eba4f478"},
+ {file = "grpcio-1.47.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:664a270d3eac68183ad049665b0f4d0262ec387d5c08c0108dbcfe5b351a8b4d"},
+ {file = "grpcio-1.47.0-cp36-cp36m-win32.whl", hash = "sha256:9298d6f2a81f132f72a7e79cbc90a511fffacc75045c2b10050bb87b86c8353d"},
+ {file = "grpcio-1.47.0-cp36-cp36m-win_amd64.whl", hash = "sha256:815089435d0f113719eabf105832e4c4fa1726b39ae3fb2ca7861752b0f70570"},
+ {file = "grpcio-1.47.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:7191ffc8bcf8a630c547287ab103e1fdf72b2e0c119e634d8a36055c1d988ad0"},
+ {file = "grpcio-1.47.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:1ec63bbd09586e5cda1bdc832ae6975d2526d04433a764a1cc866caa399e50d4"},
+ {file = "grpcio-1.47.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:08307dc5a6ac4da03146d6c00f62319e0665b01c6ffe805cfcaa955c17253f9c"},
+ {file = "grpcio-1.47.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:668350ea02af018ca945bd629754d47126b366d981ab88e0369b53bc781ffb14"},
+ {file = "grpcio-1.47.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64e097dd08bb408afeeaee9a56f75311c9ca5b27b8b0278279dc8eef85fa1051"},
+ {file = "grpcio-1.47.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:0d8a7f3eb6f290189f48223a5f4464c99619a9de34200ce80d5092fb268323d2"},
+ {file = "grpcio-1.47.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:f89de64d9eb3478b188859214752db50c91a749479011abd99e248550371375f"},
+ {file = "grpcio-1.47.0-cp37-cp37m-win32.whl", hash = "sha256:67cd275a651532d28620eef677b97164a5438c5afcfd44b15e8992afa9eb598c"},
+ {file = "grpcio-1.47.0-cp37-cp37m-win_amd64.whl", hash = "sha256:f515782b168a4ec6ea241add845ccfebe187fc7b09adf892b3ad9e2592c60af1"},
+ {file = "grpcio-1.47.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:91cd292373e85a52c897fa5b4768c895e20a7dc3423449c64f0f96388dd1812e"},
+ {file = "grpcio-1.47.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:a278d02272214ec33f046864a24b5f5aab7f60f855de38c525e5b4ef61ec5b48"},
+ {file = "grpcio-1.47.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:bfdb8af4801d1c31a18d54b37f4e49bb268d1f485ecf47f70e78d56e04ff37a7"},
+ {file = "grpcio-1.47.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e63e0619a5627edb7a5eb3e9568b9f97e604856ba228cc1d8a9f83ce3d0466e"},
+ {file = "grpcio-1.47.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc34d182c4fd64b6ff8304a606b95e814e4f8ed4b245b6d6cc9607690e3ef201"},
+ {file = "grpcio-1.47.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a6b2432ac2353c80a56d9015dfc5c4af60245c719628d4193ecd75ddf9cd248c"},
+ {file = "grpcio-1.47.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fcd5d932842df503eb0bf60f9cc35e6fe732b51f499e78b45234e0be41b0018d"},
+ {file = "grpcio-1.47.0-cp38-cp38-win32.whl", hash = "sha256:43857d06b2473b640467467f8f553319b5e819e54be14c86324dad83a0547818"},
+ {file = "grpcio-1.47.0-cp38-cp38-win_amd64.whl", hash = "sha256:96cff5a2081db82fb710db6a19dd8f904bdebb927727aaf4d9c427984b79a4c1"},
+ {file = "grpcio-1.47.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:68b5e47fcca8481f36ef444842801928e60e30a5b3852c9f4a95f2582d10dcb2"},
+ {file = "grpcio-1.47.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0cd44d78f302ff67f11a8c49b786c7ccbed2cfef6f4fd7bb0c3dc9255415f8f7"},
+ {file = "grpcio-1.47.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:4706c78b0c183dca815bbb4ef3e8dd2136ccc8d1699f62c585e75e211ad388f6"},
+ {file = "grpcio-1.47.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:324e363bad4d89a8ec7124013371f268d43afd0ac0fdeec1b21c1a101eb7dafb"},
+ {file = "grpcio-1.47.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b821403907e865e8377af3eee62f0cb233ea2369ba0fcdce9505ca5bfaf4eeb3"},
+ {file = "grpcio-1.47.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:2061dbe41e43b0a5e1fd423e8a7fb3a0cf11d69ce22d0fac21f1a8c704640b12"},
+ {file = "grpcio-1.47.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8dbef03853a0dbe457417c5469cb0f9d5bf47401b49d50c7dad3c495663b699b"},
+ {file = "grpcio-1.47.0-cp39-cp39-win32.whl", hash = "sha256:090dfa19f41efcbe760ae59b34da4304d4be9a59960c9682b7eab7e0b6748a79"},
+ {file = "grpcio-1.47.0-cp39-cp39-win_amd64.whl", hash = "sha256:55cd8b13c5ef22003889f599b8f2930836c6f71cd7cf3fc0196633813dc4f928"},
+ {file = "grpcio-1.47.0.tar.gz", hash = "sha256:5dbba95fab9b35957b4977b8904fc1fa56b302f9051eff4d7716ebb0c087f801"},
+]
+gunicorn = [
+ {file = "gunicorn-20.1.0-py3-none-any.whl", hash = "sha256:9dcc4547dbb1cb284accfb15ab5667a0e5d1881cc443e0677b4882a4067a807e"},
+ {file = "gunicorn-20.1.0.tar.gz", hash = "sha256:e0a968b5ba15f8a328fdfd7ab1fcb5af4470c28aaf7e55df02a99bc13138e6e8"},
+]
+hydra-core = [
+ {file = "hydra-core-1.2.0.tar.gz", hash = "sha256:4990721ce4ac69abafaffee566d6b63a54faa6501ecce65b338d3251446ff634"},
+ {file = "hydra_core-1.2.0-py3-none-any.whl", hash = "sha256:b6614fd6d6a97a9499f7ddbef02c9dd38f2fec6a9bc83c10e248db1dae50a528"},
+]
+hydra-submitit-launcher = [
+ {file = "hydra-submitit-launcher-1.2.0.tar.gz", hash = "sha256:e14c8eb46d020fac60ba25f82bcc368dc55851d2683dc95c88631ffcf15e4a34"},
+ {file = "hydra_submitit_launcher-1.2.0-py3-none-any.whl", hash = "sha256:51ce468fbc91211c3a46677eefde94bbb9f721c9545af0be6dd0a95658515613"},
+]
+hydra-zen = [
+ {file = "hydra_zen-0.7.1-py3-none-any.whl", hash = "sha256:231bd96d1368a62bc032465257dfd9a5f16c3d35a96db97e615ebb47cc4b29f2"},
+ {file = "hydra_zen-0.7.1.tar.gz", hash = "sha256:c3d9b88e504b140800d44f36e9a7163f83fe121cc9e5c0bafdd7feb71510b593"},
+]
+identify = [
+ {file = "identify-2.5.1-py2.py3-none-any.whl", hash = "sha256:0dca2ea3e4381c435ef9c33ba100a78a9b40c0bab11189c7cf121f75815efeaa"},
+ {file = "identify-2.5.1.tar.gz", hash = "sha256:3d11b16f3fe19f52039fb7e39c9c884b21cb1b586988114fbe42671f03de3e82"},
+]
+idna = [
+ {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
+ {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
+]
+imageio = [
+ {file = "imageio-2.19.3-py3-none-any.whl", hash = "sha256:d36ab8616175a093676693a4dbc85c6cc767f981c9ce93041422569c76d06347"},
+ {file = "imageio-2.19.3.tar.gz", hash = "sha256:0c9df80e42f2ee68bea92001e7fcf612aa149910efe040eb757f5ce323250ae1"},
+]
+imageio-ffmpeg = [
+ {file = "imageio-ffmpeg-0.4.7.tar.gz", hash = "sha256:7a08838f97f363e37ca41821b864fd3fdc99ab1fe2421040c78eb5f56a9e723e"},
+ {file = "imageio_ffmpeg-0.4.7-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl", hash = "sha256:6514f1380daf42815bc8c83aad63f33e0b8b47133421ddafe7b410cd8dfbbea5"},
+ {file = "imageio_ffmpeg-0.4.7-py3-none-manylinux2010_x86_64.whl", hash = "sha256:27b48c32becae1658aa81c3a6b922538e4099edf5fbcbdb4ff5dbc84b8ffd3d3"},
+ {file = "imageio_ffmpeg-0.4.7-py3-none-manylinux2014_aarch64.whl", hash = "sha256:fc60686ef03c2d0f842901b206223c30051a6a120384458761390104470846fd"},
+ {file = "imageio_ffmpeg-0.4.7-py3-none-win32.whl", hash = "sha256:6aba52ddf0a64442ffcb8d30ac6afb668186acec99ecbc7ae5bd171c4f500bbc"},
+ {file = "imageio_ffmpeg-0.4.7-py3-none-win_amd64.whl", hash = "sha256:8e724d12dfe83e2a6eb39619e820243ca96c81c47c2648e66e05f7ee24e14312"},
+]
+importlib-metadata = [
+ {file = "importlib_metadata-4.2.0-py3-none-any.whl", hash = "sha256:057e92c15bc8d9e8109738a48db0ccb31b4d9d5cfbee5a8670879a30be66304b"},
+ {file = "importlib_metadata-4.2.0.tar.gz", hash = "sha256:b7e52a1f8dec14a75ea73e0891f3060099ca1d8e6a462a4dff11c3e119ea1b31"},
+]
+importlib-resources = [
+ {file = "importlib_resources-5.8.0-py3-none-any.whl", hash = "sha256:7952325ffd516c05a8ad0858c74dff2c3343f136fe66a6002b2623dd1d43f223"},
+ {file = "importlib_resources-5.8.0.tar.gz", hash = "sha256:568c9f16cb204f9decc8d6d24a572eeea27dacbb4cee9e6b03a8025736769751"},
+]
+iniconfig = [
+ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
+ {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
+]
+ipykernel = [
+ {file = "ipykernel-6.16.2-py3-none-any.whl", hash = "sha256:67daf93e5b52456cd8eea87a8b59405d2bb80ae411864a1ea206c3631d8179af"},
+ {file = "ipykernel-6.16.2.tar.gz", hash = "sha256:463f3d87a92e99969b1605cb7a5b4d7b36b7145a0e72d06e65918a6ddefbe630"},
+]
+ipython = [
+ {file = "ipython-7.34.0-py3-none-any.whl", hash = "sha256:c175d2440a1caff76116eb719d40538fbb316e214eda85c5515c303aacbfb23e"},
+ {file = "ipython-7.34.0.tar.gz", hash = "sha256:af3bdb46aa292bce5615b1b2ebc76c2080c5f77f54bda2ec72461317273e7cd6"},
+]
+ipython_genutils = [
+ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"},
+ {file = "ipython_genutils-0.2.0.tar.gz", hash = "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8"},
+]
+ipywidgets = [
+ {file = "ipywidgets-8.0.6-py3-none-any.whl", hash = "sha256:a60bf8d2528997e05ac83fd19ea2fbe65f2e79fbe1b2b35779bdfc46c2941dcc"},
+ {file = "ipywidgets-8.0.6.tar.gz", hash = "sha256:de7d779f2045d60de9f6c25f653fdae2dba57898e6a1284494b3ba20b6893bb8"},
+]
+isort = [
+ {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"},
+ {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"},
+]
+itsdangerous = [
+ {file = "itsdangerous-2.1.2-py3-none-any.whl", hash = "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44"},
+ {file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"},
+]
+jedi = [
+ {file = "jedi-0.18.2-py2.py3-none-any.whl", hash = "sha256:203c1fd9d969ab8f2119ec0a3342e0b49910045abe6af0a3ae83a5764d54639e"},
+ {file = "jedi-0.18.2.tar.gz", hash = "sha256:bae794c30d07f6d910d32a7048af09b5a39ed740918da923c6b780790ebac612"},
+]
+Jinja2 = [
+ {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
+ {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
+]
+jmespath = [
+ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
+ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
+]
+joblib = [
+ {file = "joblib-1.1.0-py2.py3-none-any.whl", hash = "sha256:f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6"},
+ {file = "joblib-1.1.0.tar.gz", hash = "sha256:4158fcecd13733f8be669be0683b96ebdbbd38d23559f54dca7205aea1bf1e35"},
+]
+jsonschema = [
+ {file = "jsonschema-4.17.3-py3-none-any.whl", hash = "sha256:a870ad254da1a8ca84b6a2905cac29d265f805acc57af304784962a2aa6508f6"},
+ {file = "jsonschema-4.17.3.tar.gz", hash = "sha256:0f864437ab8b6076ba6707453ef8f98a6a0d512a80e93f8abdb676f737ecb60d"},
+]
+jupyter = [
+ {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"},
+ {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"},
+ {file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"},
+]
+jupyter-client = [
+ {file = "jupyter_client-7.4.9-py3-none-any.whl", hash = "sha256:214668aaea208195f4c13d28eb272ba79f945fc0cf3f11c7092c20b2ca1980e7"},
+ {file = "jupyter_client-7.4.9.tar.gz", hash = "sha256:52be28e04171f07aed8f20e1616a5a552ab9fee9cbbe6c1896ae170c3880d392"},
+]
+jupyter-console = [
+ {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"},
+ {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"},
+]
+jupyter-core = [
+ {file = "jupyter_core-4.12.0-py3-none-any.whl", hash = "sha256:a54672c539333258495579f6964144924e0aa7b07f7069947bef76d7ea5cb4c1"},
+ {file = "jupyter_core-4.12.0.tar.gz", hash = "sha256:87f39d7642412ae8a52291cc68e71ac01dfa2c735df2701f8108251d51b4f460"},
+]
+jupyter-server = [
+ {file = "jupyter_server-1.24.0-py3-none-any.whl", hash = "sha256:c88ddbe862966ea1aea8c3ccb89a5903abd8fbcfe5cd14090ef549d403332c37"},
+ {file = "jupyter_server-1.24.0.tar.gz", hash = "sha256:23368e8e214baf82b313d4c5a0d828ca73015e1a192ce3829bd74e62fab8d046"},
+]
+jupyterlab-pygments = [
+ {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"},
+ {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = "sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"},
+]
+jupyterlab-widgets = [
+ {file = "jupyterlab_widgets-3.0.7-py3-none-any.whl", hash = "sha256:c73f8370338ec19f1bec47254752d6505b03601cbd5a67e6a0b184532f73a459"},
+ {file = "jupyterlab_widgets-3.0.7.tar.gz", hash = "sha256:c3a50ed5bf528a0c7a869096503af54702f86dda1db469aee1c92dc0c01b43ca"},
+]
+kiwisolver = [
+ {file = "kiwisolver-1.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fd2842a0faed9ab9aba0922c951906132d9384be89690570f0ed18cd4f20e658"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:caa59e2cae0e23b1e225447d7a9ddb0f982f42a6a22d497a484dfe62a06f7c0e"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1d2c744aeedce22c122bb42d176b4aa6d063202a05a4abdacb3e413c214b3694"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:afe173ac2646c2636305ab820cc0380b22a00a7bca4290452e7166b4f4fa49d0"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:40240da438c0ebfe2aa76dd04b844effac6679423df61adbe3437d32f23468d9"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21a3a98f0a21fc602663ca9bce2b12a4114891bdeba2dea1e9ad84db59892fca"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51078855a16b7a4984ed2067b54e35803d18bca9861cb60c60f6234b50869a56"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c16635f8dddbeb1b827977d0b00d07b644b040aeb9ff8607a9fc0997afa3e567"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-win32.whl", hash = "sha256:2d76780d9c65c7529cedd49fa4802d713e60798d8dc3b0d5b12a0a8f38cca51c"},
+ {file = "kiwisolver-1.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:3a297d77b3d6979693f5948df02b89431ae3645ec95865e351fb45578031bdae"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ca3eefb02ef17257fae8b8555c85e7c1efdfd777f671384b0e4ef27409b02720"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d248c46c0aa406695bda2abf99632db991f8b3a6d46018721a2892312a99f069"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cb55258931448d61e2d50187de4ee66fc9d9f34908b524949b8b2b93d0c57136"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86bcf0009f2012847a688f2f4f9b16203ca4c835979a02549aa0595d9f457cc8"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e7cf940af5fee00a92e281eb157abe8770227a5255207818ea9a34e54a29f5b2"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:dd22085446f3eca990d12a0878eeb5199dc9553b2e71716bfe7bed9915a472ab"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-win32.whl", hash = "sha256:d2578e5149ff49878934debfacf5c743fab49eca5ecdb983d0b218e1e554c498"},
+ {file = "kiwisolver-1.4.3-cp37-cp37m-win_amd64.whl", hash = "sha256:5fb73cc8a34baba1dfa546ae83b9c248ef6150c238b06fc53d2773685b67ec67"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f70f3d028794e31cf9d1a822914efc935aadb2438ec4e8d4871d95eb1ce032d6"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:71af5b43e4fa286a35110fc5bb740fdeae2b36ca79fbcf0a54237485baeee8be"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:26b5a70bdab09e6a2f40babc4f8f992e3771751e144bda1938084c70d3001c09"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1858ad3cb686eccc7c6b7c5eac846a1cfd45aacb5811b2cf575e80b208f5622a"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4dc350cb65fe4e3f737d50f0465fa6ea0dcae0e5722b7edf5d5b0a0e3cd2c3c7"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:007799c7fa934646318fc128b033bb6e6baabe7fbad521bfb2279aac26225cd7"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:46fb56fde006b7ef5f8eaa3698299b0ea47444238b869ff3ced1426aa9fedcb5"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b9eb88593159a53a5ee0b0159daee531ff7dd9c87fa78f5d807ca059c7eb1b2b"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-win32.whl", hash = "sha256:3b1dcbc49923ac3c973184a82832e1f018dec643b1e054867d04a3a22255ec6a"},
+ {file = "kiwisolver-1.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:7118ca592d25b2957ff7b662bc0fe4f4c2b5d5b27814b9b1bc9f2fb249a970e7"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:747190fcdadc377263223f8f72b038381b3b549a8a3df5baf4d067da4749b046"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fd628e63ffdba0112e3ddf1b1e9f3db29dd8262345138e08f4938acbc6d0805a"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:22ccba48abae827a0f952a78a7b1a7ff01866131e5bbe1f826ce9bda406bf051"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:af24b21c2283ca69c416a8a42cde9764dc36c63d3389645d28c69b0e93db3cd7"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:547111ef7cf13d73546c2de97ce434935626c897bdec96a578ca100b5fcd694b"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84f85adfebd7d3c3db649efdf73659e1677a2cf3fa6e2556a3f373578af14bf7"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ffd7cf165ff71afb202b3f36daafbf298932bee325aac9f58e1c9cd55838bef0"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6b3136eecf7e1b4a4d23e4b19d6c4e7a8e0b42d55f30444e3c529700cdacaa0d"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-win32.whl", hash = "sha256:46c6e5018ba31d5ee7582f323d8661498a154dea1117486a571db4c244531f24"},
+ {file = "kiwisolver-1.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:8395064d63b26947fa2c9faeea9c3eee35e52148c5339c37987e1d96fbf009b3"},
+ {file = "kiwisolver-1.4.3-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:325fa1b15098e44fe4590a6c5c09a212ca10c6ebb5d96f7447d675f6c8340e4e"},
+ {file = "kiwisolver-1.4.3-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:654280c5f41831ddcc5a331c0e3ce2e480bbc3d7c93c18ecf6236313aae2d61a"},
+ {file = "kiwisolver-1.4.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ae7aa0784aeadfbd693c27993727792fbe1455b84d49970bad5886b42976b18"},
+ {file = "kiwisolver-1.4.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:130c6c35eded399d3967cf8a542c20b671f5ba85bd6f210f8b939f868360e9eb"},
+ {file = "kiwisolver-1.4.3.tar.gz", hash = "sha256:ab8a15c2750ae8d53e31f77a94f846d0a00772240f1c12817411fa2344351f86"},
+]
+Mako = [
+ {file = "Mako-1.2.3-py3-none-any.whl", hash = "sha256:c413a086e38cd885088d5e165305ee8eed04e8b3f8f62df343480da0a385735f"},
+ {file = "Mako-1.2.3.tar.gz", hash = "sha256:7fde96466fcfeedb0eed94f187f20b23d85e4cb41444be0e542e2c8c65c396cd"},
+]
+markdown = [
+ {file = "Markdown-3.3.5-py3-none-any.whl", hash = "sha256:0d2d09f75cb8d1ffc6770c65c61770b23a61708101f47bda416a002a0edbc480"},
+ {file = "Markdown-3.3.5.tar.gz", hash = "sha256:26e9546bfbcde5fcd072bd8f612c9c1b6e2677cb8aadbdf65206674f46dde069"},
+]
+MarkupSafe = [
+ {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c1bfff05d95783da83491be968e8fe789263689c02724e0c691933c52994f5"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7bd98b796e2b6553da7225aeb61f447f80a1ca64f41d83612e6139ca5213aa4"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b09bf97215625a311f669476f44b8b318b075847b49316d3e28c08e41a7a573f"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:694deca8d702d5db21ec83983ce0bb4b26a578e71fbdbd4fdcd387daa90e4d5e"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-win32.whl", hash = "sha256:4a33dea2b688b3190ee12bd7cfa29d39c9ed176bda40bfa11099a3ce5d3a7ac6"},
+ {file = "MarkupSafe-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:dda30ba7e87fbbb7eab1ec9f58678558fd9a6b8b853530e176eabd064da81417"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:671cd1187ed5e62818414afe79ed29da836dde67166a9fac6d435873c44fdd02"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3799351e2336dc91ea70b034983ee71cf2f9533cdff7c14c90ea126bfd95d65a"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e72591e9ecd94d7feb70c1cbd7be7b3ebea3f548870aa91e2732960fa4d57a37"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fbf47b5d3728c6aea2abb0589b5d30459e369baa772e0f37a0320185e87c980"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d5ee4f386140395a2c818d149221149c54849dfcfcb9f1debfe07a8b8bd63f9a"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:bcb3ed405ed3222f9904899563d6fc492ff75cce56cba05e32eff40e6acbeaa3"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e1c0b87e09fa55a220f058d1d49d3fb8df88fbfab58558f1198e08c1e1de842a"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-win32.whl", hash = "sha256:8dc1c72a69aa7e082593c4a203dcf94ddb74bb5c8a731e4e1eb68d031e8498ff"},
+ {file = "MarkupSafe-2.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:97a68e6ada378df82bc9f16b800ab77cbf4b2fada0081794318520138c088e4a"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8c843bbcda3a2f1e3c2ab25913c80a3c5376cd00c6e8c4a86a89a28c8dc5452"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0212a68688482dc52b2d45013df70d169f542b7394fc744c02a57374a4207003"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e576a51ad59e4bfaac456023a78f6b5e6e7651dcd383bcc3e18d06f9b55d6d1"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b9fe39a2ccc108a4accc2676e77da025ce383c108593d65cc909add5c3bd601"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96e37a3dc86e80bf81758c152fe66dbf60ed5eca3d26305edf01892257049925"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6d0072fea50feec76a4c418096652f2c3238eaa014b2f94aeb1d56a66b41403f"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:089cf3dbf0cd6c100f02945abeb18484bd1ee57a079aefd52cffd17fba910b88"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6a074d34ee7a5ce3effbc526b7083ec9731bb3cbf921bbe1d3005d4d2bdb3a63"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-win32.whl", hash = "sha256:421be9fbf0ffe9ffd7a378aafebbf6f4602d564d34be190fc19a193232fd12b1"},
+ {file = "MarkupSafe-2.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e04e26803c9c3851c931eac40c695602c6295b8d432cbe78609649ad9bd2da8a"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b87db4360013327109564f0e591bd2a3b318547bcef31b468a92ee504d07ae4f"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99a2a507ed3ac881b975a2976d59f38c19386d128e7a9a18b7df6fff1fd4c1d6"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56442863ed2b06d19c37f94d999035e15ee982988920e12a5b4ba29b62ad1f77"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ce11ee3f23f79dbd06fb3d63e2f6af7b12db1d46932fe7bd8afa259a5996603"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:33b74d289bd2f5e527beadcaa3f401e0df0a89927c1559c8566c066fa4248ab7"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:43093fb83d8343aac0b1baa75516da6092f58f41200907ef92448ecab8825135"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8e3dcf21f367459434c18e71b2a9532d96547aef8a871872a5bd69a715c15f96"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-win32.whl", hash = "sha256:d4306c36ca495956b6d568d276ac11fdd9c30a36f1b6eb928070dc5360b22e1c"},
+ {file = "MarkupSafe-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:46d00d6cfecdde84d40e572d63735ef81423ad31184100411e6e3388d405e247"},
+ {file = "MarkupSafe-2.1.1.tar.gz", hash = "sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b"},
+]
+matplotlib = [
+ {file = "matplotlib-3.5.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:03bbb3f5f78836855e127b5dab228d99551ad0642918ccbf3067fcd52ac7ac5e"},
+ {file = "matplotlib-3.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:49a5938ed6ef9dda560f26ea930a2baae11ea99e1c2080c8714341ecfda72a89"},
+ {file = "matplotlib-3.5.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:77157be0fc4469cbfb901270c205e7d8adb3607af23cef8bd11419600647ceed"},
+ {file = "matplotlib-3.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5844cea45d804174bf0fac219b4ab50774e504bef477fc10f8f730ce2d623441"},
+ {file = "matplotlib-3.5.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c87973ddec10812bddc6c286b88fdd654a666080fbe846a1f7a3b4ba7b11ab78"},
+ {file = "matplotlib-3.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a05f2b37222319753a5d43c0a4fd97ed4ff15ab502113e3f2625c26728040cf"},
+ {file = "matplotlib-3.5.2-cp310-cp310-win32.whl", hash = "sha256:9776e1a10636ee5f06ca8efe0122c6de57ffe7e8c843e0fb6e001e9d9256ec95"},
+ {file = "matplotlib-3.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:b4fedaa5a9aa9ce14001541812849ed1713112651295fdddd640ea6620e6cf98"},
+ {file = "matplotlib-3.5.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ee175a571e692fc8ae8e41ac353c0e07259113f4cb063b0ec769eff9717e84bb"},
+ {file = "matplotlib-3.5.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e8bda1088b941ead50caabd682601bece983cadb2283cafff56e8fcddbf7d7f"},
+ {file = "matplotlib-3.5.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9480842d5aadb6e754f0b8f4ebeb73065ac8be1855baa93cd082e46e770591e9"},
+ {file = "matplotlib-3.5.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6c623b355d605a81c661546af7f24414165a8a2022cddbe7380a31a4170fa2e9"},
+ {file = "matplotlib-3.5.2-cp37-cp37m-win32.whl", hash = "sha256:a91426ae910819383d337ba0dc7971c7cefdaa38599868476d94389a329e599b"},
+ {file = "matplotlib-3.5.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c4b82c2ae6d305fcbeb0eb9c93df2602ebd2f174f6e8c8a5d92f9445baa0c1d3"},
+ {file = "matplotlib-3.5.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ebc27ad11df3c1661f4677a7762e57a8a91dd41b466c3605e90717c9a5f90c82"},
+ {file = "matplotlib-3.5.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5a32ea6e12e80dedaca2d4795d9ed40f97bfa56e6011e14f31502fdd528b9c89"},
+ {file = "matplotlib-3.5.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a0967d4156adbd0d46db06bc1a877f0370bce28d10206a5071f9ecd6dc60b79"},
+ {file = "matplotlib-3.5.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2b696699386766ef171a259d72b203a3c75d99d03ec383b97fc2054f52e15cf"},
+ {file = "matplotlib-3.5.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7f409716119fa39b03da3d9602bd9b41142fab7a0568758cd136cd80b1bf36c8"},
+ {file = "matplotlib-3.5.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b8d3f4e71e26307e8c120b72c16671d70c5cd08ae412355c11254aa8254fb87f"},
+ {file = "matplotlib-3.5.2-cp38-cp38-win32.whl", hash = "sha256:b6c63cd01cad0ea8704f1fd586e9dc5777ccedcd42f63cbbaa3eae8dd41172a1"},
+ {file = "matplotlib-3.5.2-cp38-cp38-win_amd64.whl", hash = "sha256:75c406c527a3aa07638689586343f4b344fcc7ab1f79c396699eb550cd2b91f7"},
+ {file = "matplotlib-3.5.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4a44cdfdb9d1b2f18b1e7d315eb3843abb097869cd1ef89cfce6a488cd1b5182"},
+ {file = "matplotlib-3.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d8e129af95b156b41cb3be0d9a7512cc6d73e2b2109f82108f566dbabdbf377"},
+ {file = "matplotlib-3.5.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:364e6bca34edc10a96aa3b1d7cd76eb2eea19a4097198c1b19e89bee47ed5781"},
+ {file = "matplotlib-3.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea75df8e567743207e2b479ba3d8843537be1c146d4b1e3e395319a4e1a77fe9"},
+ {file = "matplotlib-3.5.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:44c6436868186564450df8fd2fc20ed9daaef5caad699aa04069e87099f9b5a8"},
+ {file = "matplotlib-3.5.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7d7705022df2c42bb02937a2a824f4ec3cca915700dd80dc23916af47ff05f1a"},
+ {file = "matplotlib-3.5.2-cp39-cp39-win32.whl", hash = "sha256:ee0b8e586ac07f83bb2950717e66cb305e2859baf6f00a9c39cc576e0ce9629c"},
+ {file = "matplotlib-3.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:c772264631e5ae61f0bd41313bbe48e1b9bcc95b974033e1118c9caa1a84d5c6"},
+ {file = "matplotlib-3.5.2-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:751d3815b555dcd6187ad35b21736dc12ce6925fc3fa363bbc6dc0f86f16484f"},
+ {file = "matplotlib-3.5.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:31fbc2af27ebb820763f077ec7adc79b5a031c2f3f7af446bd7909674cd59460"},
+ {file = "matplotlib-3.5.2-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4fa28ca76ac5c2b2d54bc058b3dad8e22ee85d26d1ee1b116a6fd4d2277b6a04"},
+ {file = "matplotlib-3.5.2-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:24173c23d1bcbaed5bf47b8785d27933a1ac26a5d772200a0f3e0e38f471b001"},
+ {file = "matplotlib-3.5.2.tar.gz", hash = "sha256:48cf850ce14fa18067f2d9e0d646763681948487a8080ec0af2686468b4607a2"},
+]
+matplotlib-inline = [
+ {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"},
+ {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"},
+]
+mccabe = [
+ {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"},
+ {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
+]
+mistune = [
+ {file = "mistune-2.0.5-py2.py3-none-any.whl", hash = "sha256:bad7f5d431886fcbaf5f758118ecff70d31f75231b34024a1341120340a65ce8"},
+ {file = "mistune-2.0.5.tar.gz", hash = "sha256:0246113cb2492db875c6be56974a7c893333bf26cd92891c85f63151cee09d34"},
+]
+mlflow = [
+ {file = "mlflow-1.29.0-py3-none-any.whl", hash = "sha256:24d95c6a19eccef5abfe5430680d96e9ab27c67f01cd4cde0f7384cb67a5c69a"},
+ {file = "mlflow-1.29.0.tar.gz", hash = "sha256:fad518600f515bc81cbf77053d506b769441229105b6a4bf8575feaa63a00da9"},
+]
+motmetrics = [
+ {file = "motmetrics-1.2.5-py3-none-any.whl", hash = "sha256:44052ccc7fa691df441ae420d39378f9173e31bdee8fb42474a58ea79f9f7c1c"},
+ {file = "motmetrics-1.2.5.tar.gz", hash = "sha256:3a777d5ab611cee008ae2c1acc39c7048d2b0b2eafed0f0f1ae473f35ebe34b9"},
+]
+moviepy = [
+ {file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"},
+]
+multidict = [
+ {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"},
+ {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"},
+ {file = "multidict-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:041b81a5f6b38244b34dc18c7b6aba91f9cdaf854d9a39e5ff0b58e2b5773b9c"},
+ {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fdda29a3c7e76a064f2477c9aab1ba96fd94e02e386f1e665bca1807fc5386f"},
+ {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3368bf2398b0e0fcbf46d85795adc4c259299fec50c1416d0f77c0a843a3eed9"},
+ {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4f052ee022928d34fe1f4d2bc743f32609fb79ed9c49a1710a5ad6b2198db20"},
+ {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:225383a6603c086e6cef0f2f05564acb4f4d5f019a4e3e983f572b8530f70c88"},
+ {file = "multidict-6.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50bd442726e288e884f7be9071016c15a8742eb689a593a0cac49ea093eef0a7"},
+ {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:47e6a7e923e9cada7c139531feac59448f1f47727a79076c0b1ee80274cd8eee"},
+ {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0556a1d4ea2d949efe5fd76a09b4a82e3a4a30700553a6725535098d8d9fb672"},
+ {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:626fe10ac87851f4cffecee161fc6f8f9853f0f6f1035b59337a51d29ff3b4f9"},
+ {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:8064b7c6f0af936a741ea1efd18690bacfbae4078c0c385d7c3f611d11f0cf87"},
+ {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2d36e929d7f6a16d4eb11b250719c39560dd70545356365b494249e2186bc389"},
+ {file = "multidict-6.0.2-cp310-cp310-win32.whl", hash = "sha256:fcb91630817aa8b9bc4a74023e4198480587269c272c58b3279875ed7235c293"},
+ {file = "multidict-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:8cbf0132f3de7cc6c6ce00147cc78e6439ea736cee6bca4f068bcf892b0fd658"},
+ {file = "multidict-6.0.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:05f6949d6169878a03e607a21e3b862eaf8e356590e8bdae4227eedadacf6e51"},
+ {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2c2e459f7050aeb7c1b1276763364884595d47000c1cddb51764c0d8976e608"},
+ {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0509e469d48940147e1235d994cd849a8f8195e0bca65f8f5439c56e17872a3"},
+ {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:514fe2b8d750d6cdb4712346a2c5084a80220821a3e91f3f71eec11cf8d28fd4"},
+ {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19adcfc2a7197cdc3987044e3f415168fc5dc1f720c932eb1ef4f71a2067e08b"},
+ {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b9d153e7f1f9ba0b23ad1568b3b9e17301e23b042c23870f9ee0522dc5cc79e8"},
+ {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:aef9cc3d9c7d63d924adac329c33835e0243b5052a6dfcbf7732a921c6e918ba"},
+ {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4571f1beddff25f3e925eea34268422622963cd8dc395bb8778eb28418248e43"},
+ {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:d48b8ee1d4068561ce8033d2c344cf5232cb29ee1a0206a7b828c79cbc5982b8"},
+ {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:45183c96ddf61bf96d2684d9fbaf6f3564d86b34cb125761f9a0ef9e36c1d55b"},
+ {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:75bdf08716edde767b09e76829db8c1e5ca9d8bb0a8d4bd94ae1eafe3dac5e15"},
+ {file = "multidict-6.0.2-cp37-cp37m-win32.whl", hash = "sha256:a45e1135cb07086833ce969555df39149680e5471c04dfd6a915abd2fc3f6dbc"},
+ {file = "multidict-6.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6f3cdef8a247d1eafa649085812f8a310e728bdf3900ff6c434eafb2d443b23a"},
+ {file = "multidict-6.0.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0327292e745a880459ef71be14e709aaea2f783f3537588fb4ed09b6c01bca60"},
+ {file = "multidict-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e875b6086e325bab7e680e4316d667fc0e5e174bb5611eb16b3ea121c8951b86"},
+ {file = "multidict-6.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:feea820722e69451743a3d56ad74948b68bf456984d63c1a92e8347b7b88452d"},
+ {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc57c68cb9139c7cd6fc39f211b02198e69fb90ce4bc4a094cf5fe0d20fd8b0"},
+ {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:497988d6b6ec6ed6f87030ec03280b696ca47dbf0648045e4e1d28b80346560d"},
+ {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:89171b2c769e03a953d5969b2f272efa931426355b6c0cb508022976a17fd376"},
+ {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:684133b1e1fe91eda8fa7447f137c9490a064c6b7f392aa857bba83a28cfb693"},
+ {file = "multidict-6.0.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd9fc9c4849a07f3635ccffa895d57abce554b467d611a5009ba4f39b78a8849"},
+ {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e07c8e79d6e6fd37b42f3250dba122053fddb319e84b55dd3a8d6446e1a7ee49"},
+ {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4070613ea2227da2bfb2c35a6041e4371b0af6b0be57f424fe2318b42a748516"},
+ {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:47fbeedbf94bed6547d3aa632075d804867a352d86688c04e606971595460227"},
+ {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5774d9218d77befa7b70d836004a768fb9aa4fdb53c97498f4d8d3f67bb9cfa9"},
+ {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2957489cba47c2539a8eb7ab32ff49101439ccf78eab724c828c1a54ff3ff98d"},
+ {file = "multidict-6.0.2-cp38-cp38-win32.whl", hash = "sha256:e5b20e9599ba74391ca0cfbd7b328fcc20976823ba19bc573983a25b32e92b57"},
+ {file = "multidict-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:8004dca28e15b86d1b1372515f32eb6f814bdf6f00952699bdeb541691091f96"},
+ {file = "multidict-6.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2e4a0785b84fb59e43c18a015ffc575ba93f7d1dbd272b4cdad9f5134b8a006c"},
+ {file = "multidict-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6701bf8a5d03a43375909ac91b6980aea74b0f5402fbe9428fc3f6edf5d9677e"},
+ {file = "multidict-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a007b1638e148c3cfb6bf0bdc4f82776cef0ac487191d093cdc316905e504071"},
+ {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07a017cfa00c9890011628eab2503bee5872f27144936a52eaab449be5eaf032"},
+ {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c207fff63adcdf5a485969131dc70e4b194327666b7e8a87a97fbc4fd80a53b2"},
+ {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:373ba9d1d061c76462d74e7de1c0c8e267e9791ee8cfefcf6b0b2495762c370c"},
+ {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfba7c6d5d7c9099ba21f84662b037a0ffd4a5e6b26ac07d19e423e6fdf965a9"},
+ {file = "multidict-6.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19d9bad105dfb34eb539c97b132057a4e709919ec4dd883ece5838bcbf262b80"},
+ {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:de989b195c3d636ba000ee4281cd03bb1234635b124bf4cd89eeee9ca8fcb09d"},
+ {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c40b7bbece294ae3a87c1bc2abff0ff9beef41d14188cda94ada7bcea99b0fb"},
+ {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:d16cce709ebfadc91278a1c005e3c17dd5f71f5098bfae1035149785ea6e9c68"},
+ {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:a2c34a93e1d2aa35fbf1485e5010337c72c6791407d03aa5f4eed920343dd360"},
+ {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:feba80698173761cddd814fa22e88b0661e98cb810f9f986c54aa34d281e4937"},
+ {file = "multidict-6.0.2-cp39-cp39-win32.whl", hash = "sha256:23b616fdc3c74c9fe01d76ce0d1ce872d2d396d8fa8e4899398ad64fb5aa214a"},
+ {file = "multidict-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:4bae31803d708f6f15fd98be6a6ac0b6958fcf68fda3c77a048a4f9073704aae"},
+ {file = "multidict-6.0.2.tar.gz", hash = "sha256:5ff3bd75f38e4c43f1f470f2df7a4d430b821c4ce22be384e1459cb57d6bb013"},
+]
+mypy-extensions = [
+ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
+ {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
+]
+nbclassic = [
+ {file = "nbclassic-0.5.5-py3-none-any.whl", hash = "sha256:47791b04dbcb89bf7fde910a3d848fd4793a4248a8936202453631a87da37d51"},
+ {file = "nbclassic-0.5.5.tar.gz", hash = "sha256:d2c91adc7909b0270c73e3e253d3687a6704b4f0a94bc156a37c85eba09f4d37"},
+]
+nbclient = [
+ {file = "nbclient-0.7.3-py3-none-any.whl", hash = "sha256:8fa96f7e36693d5e83408f5e840f113c14a45c279befe609904dbe05dad646d1"},
+ {file = "nbclient-0.7.3.tar.gz", hash = "sha256:26e41c6dca4d76701988bc34f64e1bfc2413ae6d368f13d7b5ac407efb08c755"},
+]
+nbconvert = [
+ {file = "nbconvert-7.3.1-py3-none-any.whl", hash = "sha256:d2e95904666f1ff77d36105b9de4e0801726f93b862d5b28f69e93d99ad3b19c"},
+ {file = "nbconvert-7.3.1.tar.gz", hash = "sha256:78685362b11d2e8058e70196fe83b09abed8df22d3e599cf271f4d39fdc48b9e"},
+]
+nbformat = [
+ {file = "nbformat-5.8.0-py3-none-any.whl", hash = "sha256:d910082bd3e0bffcf07eabf3683ed7dda0727a326c446eeb2922abe102e65162"},
+ {file = "nbformat-5.8.0.tar.gz", hash = "sha256:46dac64c781f1c34dfd8acba16547024110348f9fc7eab0f31981c2a3dc48d1f"},
+]
+nest-asyncio = [
+ {file = "nest_asyncio-1.5.6-py3-none-any.whl", hash = "sha256:b9a953fb40dceaa587d109609098db21900182b16440652454a146cffb06e8b8"},
+ {file = "nest_asyncio-1.5.6.tar.gz", hash = "sha256:d267cc1ff794403f7df692964d1d2a3fa9418ffea2a3f6859a439ff482fef290"},
+]
+nodeenv = [
+ {file = "nodeenv-1.7.0-py2.py3-none-any.whl", hash = "sha256:27083a7b96a25f2f5e1d8cb4b6317ee8aeda3bdd121394e5ac54e498028a042e"},
+ {file = "nodeenv-1.7.0.tar.gz", hash = "sha256:e0e7f7dfb85fc5394c6fe1e8fa98131a2473e04311a45afb6508f7cf1836fa2b"},
+]
+notebook = [
+ {file = "notebook-6.5.4-py3-none-any.whl", hash = "sha256:dd17e78aefe64c768737b32bf171c1c766666a21cc79a44d37a1700771cab56f"},
+ {file = "notebook-6.5.4.tar.gz", hash = "sha256:517209568bd47261e2def27a140e97d49070602eea0d226a696f42a7f16c9a4e"},
+]
+notebook-shim = [
+ {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"},
+ {file = "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"},
+]
+numpy = [
+ {file = "numpy-1.21.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8737609c3bbdd48e380d463134a35ffad3b22dc56295eff6f79fd85bd0eeeb25"},
+ {file = "numpy-1.21.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:fdffbfb6832cd0b300995a2b08b8f6fa9f6e856d562800fea9182316d99c4e8e"},
+ {file = "numpy-1.21.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3820724272f9913b597ccd13a467cc492a0da6b05df26ea09e78b171a0bb9da6"},
+ {file = "numpy-1.21.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f17e562de9edf691a42ddb1eb4a5541c20dd3f9e65b09ded2beb0799c0cf29bb"},
+ {file = "numpy-1.21.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f30427731561ce75d7048ac254dbe47a2ba576229250fb60f0fb74db96501a1"},
+ {file = "numpy-1.21.6-cp310-cp310-win32.whl", hash = "sha256:d4bf4d43077db55589ffc9009c0ba0a94fa4908b9586d6ccce2e0b164c86303c"},
+ {file = "numpy-1.21.6-cp310-cp310-win_amd64.whl", hash = "sha256:d136337ae3cc69aa5e447e78d8e1514be8c3ec9b54264e680cf0b4bd9011574f"},
+ {file = "numpy-1.21.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6aaf96c7f8cebc220cdfc03f1d5a31952f027dda050e5a703a0d1c396075e3e7"},
+ {file = "numpy-1.21.6-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:67c261d6c0a9981820c3a149d255a76918278a6b03b6a036800359aba1256d46"},
+ {file = "numpy-1.21.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a6be4cb0ef3b8c9250c19cc122267263093eee7edd4e3fa75395dfda8c17a8e2"},
+ {file = "numpy-1.21.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c4068a8c44014b2d55f3c3f574c376b2494ca9cc73d2f1bd692382b6dffe3db"},
+ {file = "numpy-1.21.6-cp37-cp37m-win32.whl", hash = "sha256:7c7e5fa88d9ff656e067876e4736379cc962d185d5cd808014a8a928d529ef4e"},
+ {file = "numpy-1.21.6-cp37-cp37m-win_amd64.whl", hash = "sha256:bcb238c9c96c00d3085b264e5c1a1207672577b93fa666c3b14a45240b14123a"},
+ {file = "numpy-1.21.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:82691fda7c3f77c90e62da69ae60b5ac08e87e775b09813559f8901a88266552"},
+ {file = "numpy-1.21.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:643843bcc1c50526b3a71cd2ee561cf0d8773f062c8cbaf9ffac9fdf573f83ab"},
+ {file = "numpy-1.21.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:357768c2e4451ac241465157a3e929b265dfac85d9214074985b1786244f2ef3"},
+ {file = "numpy-1.21.6-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9f411b2c3f3d76bba0865b35a425157c5dcf54937f82bbeb3d3c180789dd66a6"},
+ {file = "numpy-1.21.6-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4aa48afdce4660b0076a00d80afa54e8a97cd49f457d68a4342d188a09451c1a"},
+ {file = "numpy-1.21.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6a96eef20f639e6a97d23e57dd0c1b1069a7b4fd7027482a4c5c451cd7732f4"},
+ {file = "numpy-1.21.6-cp38-cp38-win32.whl", hash = "sha256:5c3c8def4230e1b959671eb959083661b4a0d2e9af93ee339c7dada6759a9470"},
+ {file = "numpy-1.21.6-cp38-cp38-win_amd64.whl", hash = "sha256:bf2ec4b75d0e9356edea834d1de42b31fe11f726a81dfb2c2112bc1eaa508fcf"},
+ {file = "numpy-1.21.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4391bd07606be175aafd267ef9bea87cf1b8210c787666ce82073b05f202add1"},
+ {file = "numpy-1.21.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:67f21981ba2f9d7ba9ade60c9e8cbaa8cf8e9ae51673934480e45cf55e953673"},
+ {file = "numpy-1.21.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee5ec40fdd06d62fe5d4084bef4fd50fd4bb6bfd2bf519365f569dc470163ab0"},
+ {file = "numpy-1.21.6-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1dbe1c91269f880e364526649a52eff93ac30035507ae980d2fed33aaee633ac"},
+ {file = "numpy-1.21.6-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d9caa9d5e682102453d96a0ee10c7241b72859b01a941a397fd965f23b3e016b"},
+ {file = "numpy-1.21.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58459d3bad03343ac4b1b42ed14d571b8743dc80ccbf27444f266729df1d6f5b"},
+ {file = "numpy-1.21.6-cp39-cp39-win32.whl", hash = "sha256:7f5ae4f304257569ef3b948810816bc87c9146e8c446053539947eedeaa32786"},
+ {file = "numpy-1.21.6-cp39-cp39-win_amd64.whl", hash = "sha256:e31f0bb5928b793169b87e3d1e070f2342b22d5245c755e2b81caa29756246c3"},
+ {file = "numpy-1.21.6-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dd1c8f6bd65d07d3810b90d02eba7997e32abbdf1277a481d698969e921a3be0"},
+ {file = "numpy-1.21.6.zip", hash = "sha256:ecb55251139706669fdec2ff073c98ef8e9a84473e51e716211b41aa0f18e656"},
+]
+oauthlib = [
+ {file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"},
+ {file = "oauthlib-3.2.0.tar.gz", hash = "sha256:23a8208d75b902797ea29fd31fa80a15ed9dc2c6c16fe73f5d346f83f6fa27a2"},
+]
+omegaconf = [
+ {file = "omegaconf-2.2.2-py3-none-any.whl", hash = "sha256:556917181487fb66fe832d3c7b324f51b2f4c8adc373dd5091be921501b7d420"},
+ {file = "omegaconf-2.2.2.tar.gz", hash = "sha256:65c85b2a84669a570c70f2df00de3cebcd9b47a8587d3c53b1aa5766bb096f77"},
+]
+packaging = [
+ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
+ {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
+]
+pandas = [
+ {file = "pandas-1.3.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:62d5b5ce965bae78f12c1c0df0d387899dd4211ec0bdc52822373f13a3a022b9"},
+ {file = "pandas-1.3.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:adfeb11be2d54f275142c8ba9bf67acee771b7186a5745249c7d5a06c670136b"},
+ {file = "pandas-1.3.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a8c055d58873ad81cae290d974d13dd479b82cbb975c3e1fa2cf1920715296"},
+ {file = "pandas-1.3.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd541ab09e1f80a2a1760032d665f6e032d8e44055d602d65eeea6e6e85498cb"},
+ {file = "pandas-1.3.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2651d75b9a167cc8cc572cf787ab512d16e316ae00ba81874b560586fa1325e0"},
+ {file = "pandas-1.3.5-cp310-cp310-win_amd64.whl", hash = "sha256:aaf183a615ad790801fa3cf2fa450e5b6d23a54684fe386f7e3208f8b9bfbef6"},
+ {file = "pandas-1.3.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:344295811e67f8200de2390093aeb3c8309f5648951b684d8db7eee7d1c81fb7"},
+ {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:552020bf83b7f9033b57cbae65589c01e7ef1544416122da0c79140c93288f56"},
+ {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cce0c6bbeb266b0e39e35176ee615ce3585233092f685b6a82362523e59e5b4"},
+ {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d28a3c65463fd0d0ba8bbb7696b23073efee0510783340a44b08f5e96ffce0c"},
+ {file = "pandas-1.3.5-cp37-cp37m-win32.whl", hash = "sha256:a62949c626dd0ef7de11de34b44c6475db76995c2064e2d99c6498c3dba7fe58"},
+ {file = "pandas-1.3.5-cp37-cp37m-win_amd64.whl", hash = "sha256:8025750767e138320b15ca16d70d5cdc1886e8f9cc56652d89735c016cd8aea6"},
+ {file = "pandas-1.3.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fe95bae4e2d579812865db2212bb733144e34d0c6785c0685329e5b60fcb85dd"},
+ {file = "pandas-1.3.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f261553a1e9c65b7a310302b9dbac31cf0049a51695c14ebe04e4bfd4a96f02"},
+ {file = "pandas-1.3.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b6dbec5f3e6d5dc80dcfee250e0a2a652b3f28663492f7dab9a24416a48ac39"},
+ {file = "pandas-1.3.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3bc49af96cd6285030a64779de5b3688633a07eb75c124b0747134a63f4c05f"},
+ {file = "pandas-1.3.5-cp38-cp38-win32.whl", hash = "sha256:b6b87b2fb39e6383ca28e2829cddef1d9fc9e27e55ad91ca9c435572cdba51bf"},
+ {file = "pandas-1.3.5-cp38-cp38-win_amd64.whl", hash = "sha256:a395692046fd8ce1edb4c6295c35184ae0c2bbe787ecbe384251da609e27edcb"},
+ {file = "pandas-1.3.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bd971a3f08b745a75a86c00b97f3007c2ea175951286cdda6abe543e687e5f2f"},
+ {file = "pandas-1.3.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37f06b59e5bc05711a518aa10beaec10942188dccb48918bb5ae602ccbc9f1a0"},
+ {file = "pandas-1.3.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c21778a688d3712d35710501f8001cdbf96eb70a7c587a3d5613573299fdca6"},
+ {file = "pandas-1.3.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3345343206546545bc26a05b4602b6a24385b5ec7c75cb6059599e3d56831da2"},
+ {file = "pandas-1.3.5-cp39-cp39-win32.whl", hash = "sha256:c69406a2808ba6cf580c2255bcf260b3f214d2664a3a4197d0e640f573b46fd3"},
+ {file = "pandas-1.3.5-cp39-cp39-win_amd64.whl", hash = "sha256:32e1a26d5ade11b547721a72f9bfc4bd113396947606e00d5b4a5b79b3dcb006"},
+ {file = "pandas-1.3.5.tar.gz", hash = "sha256:1e4285f5de1012de20ca46b188ccf33521bff61ba5c5ebd78b4fb28e5416a9f1"},
+]
+pandocfilters = [
+ {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"},
+ {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"},
+]
+parso = [
+ {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"},
+ {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"},
+]
+pathspec = [
+ {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"},
+ {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
+]
+pexpect = [
+ {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"},
+ {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"},
+]
+pickleshare = [
+ {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"},
+ {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"},
+]
+pillow = [
+ {file = "Pillow-9.0.1-1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a5d24e1d674dd9d72c66ad3ea9131322819ff86250b30dc5821cbafcfa0b96b4"},
+ {file = "Pillow-9.0.1-1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2632d0f846b7c7600edf53c48f8f9f1e13e62f66a6dbc15191029d950bfed976"},
+ {file = "Pillow-9.0.1-1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b9618823bd237c0d2575283f2939655f54d51b4527ec3972907a927acbcc5bfc"},
+ {file = "Pillow-9.0.1-cp310-cp310-macosx_10_10_universal2.whl", hash = "sha256:9bfdb82cdfeccec50aad441afc332faf8606dfa5e8efd18a6692b5d6e79f00fd"},
+ {file = "Pillow-9.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5100b45a4638e3c00e4d2320d3193bdabb2d75e79793af7c3eb139e4f569f16f"},
+ {file = "Pillow-9.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:528a2a692c65dd5cafc130de286030af251d2ee0483a5bf50c9348aefe834e8a"},
+ {file = "Pillow-9.0.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0f29d831e2151e0b7b39981756d201f7108d3d215896212ffe2e992d06bfe049"},
+ {file = "Pillow-9.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:855c583f268edde09474b081e3ddcd5cf3b20c12f26e0d434e1386cc5d318e7a"},
+ {file = "Pillow-9.0.1-cp310-cp310-win32.whl", hash = "sha256:d9d7942b624b04b895cb95af03a23407f17646815495ce4547f0e60e0b06f58e"},
+ {file = "Pillow-9.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:81c4b81611e3a3cb30e59b0cf05b888c675f97e3adb2c8672c3154047980726b"},
+ {file = "Pillow-9.0.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:413ce0bbf9fc6278b2d63309dfeefe452835e1c78398efb431bab0672fe9274e"},
+ {file = "Pillow-9.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80fe64a6deb6fcfdf7b8386f2cf216d329be6f2781f7d90304351811fb591360"},
+ {file = "Pillow-9.0.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cef9c85ccbe9bee00909758936ea841ef12035296c748aaceee535969e27d31b"},
+ {file = "Pillow-9.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d19397351f73a88904ad1aee421e800fe4bbcd1aeee6435fb62d0a05ccd1030"},
+ {file = "Pillow-9.0.1-cp37-cp37m-win32.whl", hash = "sha256:d21237d0cd37acded35154e29aec853e945950321dd2ffd1a7d86fe686814669"},
+ {file = "Pillow-9.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:ede5af4a2702444a832a800b8eb7f0a7a1c0eed55b644642e049c98d589e5092"},
+ {file = "Pillow-9.0.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:b5b3f092fe345c03bca1e0b687dfbb39364b21ebb8ba90e3fa707374b7915204"},
+ {file = "Pillow-9.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:335ace1a22325395c4ea88e00ba3dc89ca029bd66bd5a3c382d53e44f0ccd77e"},
+ {file = "Pillow-9.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db6d9fac65bd08cea7f3540b899977c6dee9edad959fa4eaf305940d9cbd861c"},
+ {file = "Pillow-9.0.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f154d173286a5d1863637a7dcd8c3437bb557520b01bddb0be0258dcb72696b5"},
+ {file = "Pillow-9.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14d4b1341ac07ae07eb2cc682f459bec932a380c3b122f5540432d8977e64eae"},
+ {file = "Pillow-9.0.1-cp38-cp38-win32.whl", hash = "sha256:effb7749713d5317478bb3acb3f81d9d7c7f86726d41c1facca068a04cf5bb4c"},
+ {file = "Pillow-9.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:7f7609a718b177bf171ac93cea9fd2ddc0e03e84d8fa4e887bdfc39671d46b00"},
+ {file = "Pillow-9.0.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:80ca33961ced9c63358056bd08403ff866512038883e74f3a4bf88ad3eb66838"},
+ {file = "Pillow-9.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1c3c33ac69cf059bbb9d1a71eeaba76781b450bc307e2291f8a4764d779a6b28"},
+ {file = "Pillow-9.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12875d118f21cf35604176872447cdb57b07126750a33748bac15e77f90f1f9c"},
+ {file = "Pillow-9.0.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:514ceac913076feefbeaf89771fd6febde78b0c4c1b23aaeab082c41c694e81b"},
+ {file = "Pillow-9.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3c5c79ab7dfce6d88f1ba639b77e77a17ea33a01b07b99840d6ed08031cb2a7"},
+ {file = "Pillow-9.0.1-cp39-cp39-win32.whl", hash = "sha256:718856856ba31f14f13ba885ff13874be7fefc53984d2832458f12c38205f7f7"},
+ {file = "Pillow-9.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:f25ed6e28ddf50de7e7ea99d7a976d6a9c415f03adcaac9c41ff6ff41b6d86ac"},
+ {file = "Pillow-9.0.1-pp37-pypy37_pp73-macosx_10_10_x86_64.whl", hash = "sha256:011233e0c42a4a7836498e98c1acf5e744c96a67dd5032a6f666cc1fb97eab97"},
+ {file = "Pillow-9.0.1-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:253e8a302a96df6927310a9d44e6103055e8fb96a6822f8b7f514bb7ef77de56"},
+ {file = "Pillow-9.0.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6295f6763749b89c994fcb6d8a7f7ce03c3992e695f89f00b741b4580b199b7e"},
+ {file = "Pillow-9.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:a9f44cd7e162ac6191491d7249cceb02b8116b0f7e847ee33f739d7cb1ea1f70"},
+ {file = "Pillow-9.0.1.tar.gz", hash = "sha256:6c8bc8238a7dfdaf7a75f5ec5a663f4173f8c367e5a39f87e720495e1eed75fa"},
+]
+pkgutil_resolve_name = [
+ {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"},
+ {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"},
+]
+platformdirs = [
+ {file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"},
+ {file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"},
+]
+pluggy = [
+ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
+ {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
+]
+pre-commit = [
+ {file = "pre_commit-2.19.0-py2.py3-none-any.whl", hash = "sha256:10c62741aa5704faea2ad69cb550ca78082efe5697d6f04e5710c3c229afdd10"},
+ {file = "pre_commit-2.19.0.tar.gz", hash = "sha256:4233a1e38621c87d9dda9808c6606d7e7ba0e087cd56d3fe03202a01d2919615"},
+]
+proglog = [
+ {file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"},
+ {file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"},
+]
+prometheus-client = [
+ {file = "prometheus_client-0.15.0-py3-none-any.whl", hash = "sha256:db7c05cbd13a0f79975592d112320f2605a325969b270a94b71dcabc47b931d2"},
+ {file = "prometheus_client-0.15.0.tar.gz", hash = "sha256:be26aa452490cfcf6da953f9436e95a9f2b4d578ca80094b4458930e5f584ab1"},
+]
+prometheus-flask-exporter = [
+ {file = "prometheus_flask_exporter-0.20.3-py3-none-any.whl", hash = "sha256:8e38ada61a24543c4ce65672db9694d0b3d0d20d4516e2f30d6ba85304cd6031"},
+ {file = "prometheus_flask_exporter-0.20.3.tar.gz", hash = "sha256:480ad73730e06ac6f6f45913595a588ce84811b6aaf11ff9532e530512e9d13d"},
+]
+prompt-toolkit = [
+ {file = "prompt_toolkit-3.0.38-py3-none-any.whl", hash = "sha256:45ea77a2f7c60418850331366c81cf6b5b9cf4c7fd34616f733c5427e6abbb1f"},
+ {file = "prompt_toolkit-3.0.38.tar.gz", hash = "sha256:23ac5d50538a9a38c8bde05fecb47d0b403ecd0662857a86f886f798563d5b9b"},
+]
+protobuf = [
+ {file = "protobuf-3.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3cc797c9d15d7689ed507b165cd05913acb992d78b379f6014e013f9ecb20996"},
+ {file = "protobuf-3.20.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ff8d8fa42675249bb456f5db06c00de6c2f4c27a065955917b28c4f15978b9c3"},
+ {file = "protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cd68be2559e2a3b84f517fb029ee611546f7812b1fdd0aa2ecc9bc6ec0e4fdde"},
+ {file = "protobuf-3.20.1-cp310-cp310-win32.whl", hash = "sha256:9016d01c91e8e625141d24ec1b20fed584703e527d28512aa8c8707f105a683c"},
+ {file = "protobuf-3.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:32ca378605b41fd180dfe4e14d3226386d8d1b002ab31c969c366549e66a2bb7"},
+ {file = "protobuf-3.20.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9be73ad47579abc26c12024239d3540e6b765182a91dbc88e23658ab71767153"},
+ {file = "protobuf-3.20.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:097c5d8a9808302fb0da7e20edf0b8d4703274d140fd25c5edabddcde43e081f"},
+ {file = "protobuf-3.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e250a42f15bf9d5b09fe1b293bdba2801cd520a9f5ea2d7fb7536d4441811d20"},
+ {file = "protobuf-3.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cdee09140e1cd184ba9324ec1df410e7147242b94b5f8b0c64fc89e38a8ba531"},
+ {file = "protobuf-3.20.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:af0ebadc74e281a517141daad9d0f2c5d93ab78e9d455113719a45a49da9db4e"},
+ {file = "protobuf-3.20.1-cp37-cp37m-win32.whl", hash = "sha256:755f3aee41354ae395e104d62119cb223339a8f3276a0cd009ffabfcdd46bb0c"},
+ {file = "protobuf-3.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:62f1b5c4cd6c5402b4e2d63804ba49a327e0c386c99b1675c8a0fefda23b2067"},
+ {file = "protobuf-3.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:06059eb6953ff01e56a25cd02cca1a9649a75a7e65397b5b9b4e929ed71d10cf"},
+ {file = "protobuf-3.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cb29edb9eab15742d791e1025dd7b6a8f6fcb53802ad2f6e3adcb102051063ab"},
+ {file = "protobuf-3.20.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:69ccfdf3657ba59569c64295b7d51325f91af586f8d5793b734260dfe2e94e2c"},
+ {file = "protobuf-3.20.1-cp38-cp38-win32.whl", hash = "sha256:dd5789b2948ca702c17027c84c2accb552fc30f4622a98ab5c51fcfe8c50d3e7"},
+ {file = "protobuf-3.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:77053d28427a29987ca9caf7b72ccafee011257561259faba8dd308fda9a8739"},
+ {file = "protobuf-3.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f50601512a3d23625d8a85b1638d914a0970f17920ff39cec63aaef80a93fb7"},
+ {file = "protobuf-3.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:284f86a6207c897542d7e956eb243a36bb8f9564c1742b253462386e96c6b78f"},
+ {file = "protobuf-3.20.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7403941f6d0992d40161aa8bb23e12575637008a5a02283a930addc0508982f9"},
+ {file = "protobuf-3.20.1-cp39-cp39-win32.whl", hash = "sha256:db977c4ca738dd9ce508557d4fce0f5aebd105e158c725beec86feb1f6bc20d8"},
+ {file = "protobuf-3.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:7e371f10abe57cee5021797126c93479f59fccc9693dafd6bd5633ab67808a91"},
+ {file = "protobuf-3.20.1-py2.py3-none-any.whl", hash = "sha256:adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388"},
+ {file = "protobuf-3.20.1.tar.gz", hash = "sha256:adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9"},
+]
+psutil = [
+ {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"},
+ {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"},
+ {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"},
+ {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"},
+ {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"},
+ {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"},
+ {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"},
+ {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"},
+ {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"},
+ {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"},
+ {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"},
+ {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"},
+ {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"},
+ {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"},
+]
+ptyprocess = [
+ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"},
+ {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"},
+]
+py = [
+ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
+ {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
+]
+pyamg = [
+ {file = "pyamg-4.2.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:236c5caeb9c811b96ea3137ed63af1050150ba1639b2507e9c7dee318e65f944"},
+ {file = "pyamg-4.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:944c4f4cc97d798bbbdadaf86185e6911d5e930ddc38fb5ed2fe683228f4166c"},
+ {file = "pyamg-4.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:487206fdd40dbfe8efc14e7b12f551c1317b03412184a2a53d35ac51bd07b9be"},
+ {file = "pyamg-4.2.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:771ab589091934bfe6c803c4e55c9e99e79e2ec6e3705d0a80b2c4360d754040"},
+ {file = "pyamg-4.2.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eba62a2c32d67128dcc5f583816310c3262714f83f94b6e2c9d3290fc8f055ff"},
+ {file = "pyamg-4.2.3-cp310-cp310-win32.whl", hash = "sha256:71887d6e1899dd4af24f6f232001feaf2ace21498bce6485333e20c36a7c1e2d"},
+ {file = "pyamg-4.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:9fea7c57b0c9b9c04d6717f9ebe9607e6cfea8afeb5be127ff9c2fb76b6b0149"},
+ {file = "pyamg-4.2.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4c38c615ed5c684051cfe9648dc1cb695ed45ff0ed84bb1417b1be9c7790d0d6"},
+ {file = "pyamg-4.2.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7ead2de0e9f44e13e6da4fe2a3a1115c513b47eae1325002622f49d52114107f"},
+ {file = "pyamg-4.2.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0172a7331136b6f0b5f017d9acf7939ddfdadd9304ce74817ab8036928125097"},
+ {file = "pyamg-4.2.3-cp37-cp37m-win32.whl", hash = "sha256:f091d6739cd3d46fc642f986be8f3dc0a72147e91da04ecef4f3636d857f5760"},
+ {file = "pyamg-4.2.3-cp37-cp37m-win_amd64.whl", hash = "sha256:96765421a6cb03b4c947d3e3ae3bb2dc01c744998914553b86fff635c122a54c"},
+ {file = "pyamg-4.2.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5c6c4105ad4c4cf4d4d6d2d8150f7910fce0bc7d155ddb2a049690ecea3f3975"},
+ {file = "pyamg-4.2.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1196a26277df0361621d46d7b8cc6b1d2de97f4448136db491edc83c8b9c9fe4"},
+ {file = "pyamg-4.2.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d2c55a49353f8a8a71584caa05b196290111a0ccc7de98c784fe9f71848560c7"},
+ {file = "pyamg-4.2.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dc2d11a01ce754e9167faa67857fe7bc920ff99b20e2fbc990e091e1aed3b148"},
+ {file = "pyamg-4.2.3-cp38-cp38-win32.whl", hash = "sha256:2d3274cd50f98416a4e85de958e9ea497914dbe1f7efa685ecf70b0f52c058aa"},
+ {file = "pyamg-4.2.3-cp38-cp38-win_amd64.whl", hash = "sha256:d42a316d834d8867422204011817b7d55774a646e94da4d7839cfaad5422f16f"},
+ {file = "pyamg-4.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fc88ce72a5de1080266dfd9c8ada5b61e4af3d2276aa6b15532300596adbac97"},
+ {file = "pyamg-4.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3ced128e79742594121d38a411d86ba0b099ebc098f3428d4b5f5ce95dd6d318"},
+ {file = "pyamg-4.2.3-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:18af99d2551df07951c35cf270dc76703f8c5d30b16ea8e61657fda098f57dd7"},
+ {file = "pyamg-4.2.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c1827249145a8876eb954db88ba6be5f58a95254d6f46cdd86e0371468a2640e"},
+ {file = "pyamg-4.2.3-cp39-cp39-win32.whl", hash = "sha256:048c269554eb1212f6095283c633a5ae2a052824bb1efa9a62f88ddbe52f19d2"},
+ {file = "pyamg-4.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:da73a026dbc93644918733cf42acc86cb2dc1dad3685afad5ace41eefe4f1119"},
+ {file = "pyamg-4.2.3.tar.gz", hash = "sha256:37ad3c1dcaff2435c2ab7c8ec36942af0726c563d35059bcd18eb07473f7182e"},
+]
+pyasn1 = [
+ {file = "pyasn1-0.4.8-py2.4.egg", hash = "sha256:fec3e9d8e36808a28efb59b489e4528c10ad0f480e57dcc32b4de5c9d8c9fdf3"},
+ {file = "pyasn1-0.4.8-py2.5.egg", hash = "sha256:0458773cfe65b153891ac249bcf1b5f8f320b7c2ce462151f8fa74de8934becf"},
+ {file = "pyasn1-0.4.8-py2.6.egg", hash = "sha256:5c9414dcfede6e441f7e8f81b43b34e834731003427e5b09e4e00e3172a10f00"},
+ {file = "pyasn1-0.4.8-py2.7.egg", hash = "sha256:6e7545f1a61025a4e58bb336952c5061697da694db1cae97b116e9c46abcf7c8"},
+ {file = "pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d"},
+ {file = "pyasn1-0.4.8-py3.1.egg", hash = "sha256:78fa6da68ed2727915c4767bb386ab32cdba863caa7dbe473eaae45f9959da86"},
+ {file = "pyasn1-0.4.8-py3.2.egg", hash = "sha256:08c3c53b75eaa48d71cf8c710312316392ed40899cb34710d092e96745a358b7"},
+ {file = "pyasn1-0.4.8-py3.3.egg", hash = "sha256:03840c999ba71680a131cfaee6fab142e1ed9bbd9c693e285cc6aca0d555e576"},
+ {file = "pyasn1-0.4.8-py3.4.egg", hash = "sha256:7ab8a544af125fb704feadb008c99a88805126fb525280b2270bb25cc1d78a12"},
+ {file = "pyasn1-0.4.8-py3.5.egg", hash = "sha256:e89bf84b5437b532b0803ba5c9a5e054d21fec423a89952a74f87fa2c9b7bce2"},
+ {file = "pyasn1-0.4.8-py3.6.egg", hash = "sha256:014c0e9976956a08139dc0712ae195324a75e142284d5f87f1a87ee1b068a359"},
+ {file = "pyasn1-0.4.8-py3.7.egg", hash = "sha256:99fcc3c8d804d1bc6d9a099921e39d827026409a58f2a720dcdb89374ea0c776"},
+ {file = "pyasn1-0.4.8.tar.gz", hash = "sha256:aef77c9fb94a3ac588e87841208bdec464471d9871bd5050a287cc9a475cd0ba"},
+]
+pyasn1-modules = [
+ {file = "pyasn1-modules-0.2.8.tar.gz", hash = "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e"},
+ {file = "pyasn1_modules-0.2.8-py2.4.egg", hash = "sha256:0fe1b68d1e486a1ed5473f1302bd991c1611d319bba158e98b106ff86e1d7199"},
+ {file = "pyasn1_modules-0.2.8-py2.5.egg", hash = "sha256:fe0644d9ab041506b62782e92b06b8c68cca799e1a9636ec398675459e031405"},
+ {file = "pyasn1_modules-0.2.8-py2.6.egg", hash = "sha256:a99324196732f53093a84c4369c996713eb8c89d360a496b599fb1a9c47fc3eb"},
+ {file = "pyasn1_modules-0.2.8-py2.7.egg", hash = "sha256:0845a5582f6a02bb3e1bde9ecfc4bfcae6ec3210dd270522fee602365430c3f8"},
+ {file = "pyasn1_modules-0.2.8-py2.py3-none-any.whl", hash = "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74"},
+ {file = "pyasn1_modules-0.2.8-py3.1.egg", hash = "sha256:f39edd8c4ecaa4556e989147ebf219227e2cd2e8a43c7e7fcb1f1c18c5fd6a3d"},
+ {file = "pyasn1_modules-0.2.8-py3.2.egg", hash = "sha256:b80486a6c77252ea3a3e9b1e360bc9cf28eaac41263d173c032581ad2f20fe45"},
+ {file = "pyasn1_modules-0.2.8-py3.3.egg", hash = "sha256:65cebbaffc913f4fe9e4808735c95ea22d7a7775646ab690518c056784bc21b4"},
+ {file = "pyasn1_modules-0.2.8-py3.4.egg", hash = "sha256:15b7c67fabc7fc240d87fb9aabf999cf82311a6d6fb2c70d00d3d0604878c811"},
+ {file = "pyasn1_modules-0.2.8-py3.5.egg", hash = "sha256:426edb7a5e8879f1ec54a1864f16b882c2837bfd06eee62f2c982315ee2473ed"},
+ {file = "pyasn1_modules-0.2.8-py3.6.egg", hash = "sha256:cbac4bc38d117f2a49aeedec4407d23e8866ea4ac27ff2cf7fb3e5b570df19e0"},
+ {file = "pyasn1_modules-0.2.8-py3.7.egg", hash = "sha256:c29a5e5cc7a3f05926aff34e097e84f8589cd790ce0ed41b67aed6857b26aafd"},
+]
+pycodestyle = [
+ {file = "pycodestyle-2.8.0-py2.py3-none-any.whl", hash = "sha256:720f8b39dde8b293825e7ff02c475f3077124006db4f440dcbc9a20b76548a20"},
+ {file = "pycodestyle-2.8.0.tar.gz", hash = "sha256:eddd5847ef438ea1c7870ca7eb78a9d47ce0cdb4851a5523949f2601d0cbbe7f"},
+]
+pycparser = [
+ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
+ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
+]
+pydeprecate = [
+ {file = "pyDeprecate-0.3.2-py3-none-any.whl", hash = "sha256:ed86b68ed837e6465245904a3de2f59bf9eef78ac7a2502ee280533d04802457"},
+ {file = "pyDeprecate-0.3.2.tar.gz", hash = "sha256:d481116cc5d7f6c473e7c4be820efdd9b90a16b594b350276e9e66a6cb5bdd29"},
+]
+pydocstyle = [
+ {file = "pydocstyle-6.1.1-py3-none-any.whl", hash = "sha256:6987826d6775056839940041beef5c08cc7e3d71d63149b48e36727f70144dc4"},
+ {file = "pydocstyle-6.1.1.tar.gz", hash = "sha256:1d41b7c459ba0ee6c345f2eb9ae827cab14a7533a88c5c6f7e94923f72df92dc"},
+]
+pyflakes = [
+ {file = "pyflakes-2.4.0-py2.py3-none-any.whl", hash = "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e"},
+ {file = "pyflakes-2.4.0.tar.gz", hash = "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c"},
+]
+Pygments = [
+ {file = "Pygments-2.15.1-py3-none-any.whl", hash = "sha256:db2db3deb4b4179f399a09054b023b6a586b76499d36965813c71aa8ed7b5fd1"},
+ {file = "Pygments-2.15.1.tar.gz", hash = "sha256:8ace4d3c1dd481894b2005f560ead0f9f19ee64fe983366be1a21e171d12775c"},
+]
+PyJWT = [
+ {file = "PyJWT-2.6.0-py3-none-any.whl", hash = "sha256:d83c3d892a77bbb74d3e1a2cfa90afaadb60945205d1095d9221f04466f64c14"},
+ {file = "PyJWT-2.6.0.tar.gz", hash = "sha256:69285c7e31fc44f68a1feb309e948e0df53259d579295e6cfe2b1792329f05fd"},
+]
+pyparsing = [
+ {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
+ {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
+]
+pyrsistent = [
+ {file = "pyrsistent-0.19.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:20460ac0ea439a3e79caa1dbd560344b64ed75e85d8703943e0b66c2a6150e4a"},
+ {file = "pyrsistent-0.19.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c18264cb84b5e68e7085a43723f9e4c1fd1d935ab240ce02c0324a8e01ccb64"},
+ {file = "pyrsistent-0.19.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b774f9288dda8d425adb6544e5903f1fb6c273ab3128a355c6b972b7df39dcf"},
+ {file = "pyrsistent-0.19.3-cp310-cp310-win32.whl", hash = "sha256:5a474fb80f5e0d6c9394d8db0fc19e90fa540b82ee52dba7d246a7791712f74a"},
+ {file = "pyrsistent-0.19.3-cp310-cp310-win_amd64.whl", hash = "sha256:49c32f216c17148695ca0e02a5c521e28a4ee6c5089f97e34fe24163113722da"},
+ {file = "pyrsistent-0.19.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f0774bf48631f3a20471dd7c5989657b639fd2d285b861237ea9e82c36a415a9"},
+ {file = "pyrsistent-0.19.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ab2204234c0ecd8b9368dbd6a53e83c3d4f3cab10ecaf6d0e772f456c442393"},
+ {file = "pyrsistent-0.19.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e42296a09e83028b3476f7073fcb69ffebac0e66dbbfd1bd847d61f74db30f19"},
+ {file = "pyrsistent-0.19.3-cp311-cp311-win32.whl", hash = "sha256:64220c429e42a7150f4bfd280f6f4bb2850f95956bde93c6fda1b70507af6ef3"},
+ {file = "pyrsistent-0.19.3-cp311-cp311-win_amd64.whl", hash = "sha256:016ad1afadf318eb7911baa24b049909f7f3bb2c5b1ed7b6a8f21db21ea3faa8"},
+ {file = "pyrsistent-0.19.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c4db1bd596fefd66b296a3d5d943c94f4fac5bcd13e99bffe2ba6a759d959a28"},
+ {file = "pyrsistent-0.19.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aeda827381f5e5d65cced3024126529ddc4289d944f75e090572c77ceb19adbf"},
+ {file = "pyrsistent-0.19.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42ac0b2f44607eb92ae88609eda931a4f0dfa03038c44c772e07f43e738bcac9"},
+ {file = "pyrsistent-0.19.3-cp37-cp37m-win32.whl", hash = "sha256:e8f2b814a3dc6225964fa03d8582c6e0b6650d68a232df41e3cc1b66a5d2f8d1"},
+ {file = "pyrsistent-0.19.3-cp37-cp37m-win_amd64.whl", hash = "sha256:c9bb60a40a0ab9aba40a59f68214eed5a29c6274c83b2cc206a359c4a89fa41b"},
+ {file = "pyrsistent-0.19.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a2471f3f8693101975b1ff85ffd19bb7ca7dd7c38f8a81701f67d6b4f97b87d8"},
+ {file = "pyrsistent-0.19.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc5d149f31706762c1f8bda2e8c4f8fead6e80312e3692619a75301d3dbb819a"},
+ {file = "pyrsistent-0.19.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3311cb4237a341aa52ab8448c27e3a9931e2ee09561ad150ba94e4cfd3fc888c"},
+ {file = "pyrsistent-0.19.3-cp38-cp38-win32.whl", hash = "sha256:f0e7c4b2f77593871e918be000b96c8107da48444d57005b6a6bc61fb4331b2c"},
+ {file = "pyrsistent-0.19.3-cp38-cp38-win_amd64.whl", hash = "sha256:c147257a92374fde8498491f53ffa8f4822cd70c0d85037e09028e478cababb7"},
+ {file = "pyrsistent-0.19.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b735e538f74ec31378f5a1e3886a26d2ca6351106b4dfde376a26fc32a044edc"},
+ {file = "pyrsistent-0.19.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99abb85579e2165bd8522f0c0138864da97847875ecbd45f3e7e2af569bfc6f2"},
+ {file = "pyrsistent-0.19.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a8cb235fa6d3fd7aae6a4f1429bbb1fec1577d978098da1252f0489937786f3"},
+ {file = "pyrsistent-0.19.3-cp39-cp39-win32.whl", hash = "sha256:c74bed51f9b41c48366a286395c67f4e894374306b197e62810e0fdaf2364da2"},
+ {file = "pyrsistent-0.19.3-cp39-cp39-win_amd64.whl", hash = "sha256:878433581fc23e906d947a6814336eee031a00e6defba224234169ae3d3d6a98"},
+ {file = "pyrsistent-0.19.3-py3-none-any.whl", hash = "sha256:ccf0d6bd208f8111179f0c26fdf84ed7c3891982f2edaeae7422575f47e66b64"},
+ {file = "pyrsistent-0.19.3.tar.gz", hash = "sha256:1a2994773706bbb4995c31a97bc94f1418314923bd1048c6d964837040376440"},
+]
+pytest = [
+ {file = "pytest-7.1.2-py3-none-any.whl", hash = "sha256:13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c"},
+ {file = "pytest-7.1.2.tar.gz", hash = "sha256:a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45"},
+]
+python-dateutil = [
+ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
+ {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
+]
+pytorch-lightning = [
+ {file = "pytorch-lightning-1.6.4.tar.gz", hash = "sha256:5459f2c3e67676ec59e94576d1499e9559d214e7df41eadd135db64b4ccf54b9"},
+ {file = "pytorch_lightning-1.6.4-py3-none-any.whl", hash = "sha256:0f42f93116a3fcb6fd8c9ea45cf7c918e4aa3f848ae21d0e9ac2bf39f2865dd7"},
+]
+pytz = [
+ {file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"},
+ {file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"},
+]
+pywin32 = [
+ {file = "pywin32-304-cp310-cp310-win32.whl", hash = "sha256:3c7bacf5e24298c86314f03fa20e16558a4e4138fc34615d7de4070c23e65af3"},
+ {file = "pywin32-304-cp310-cp310-win_amd64.whl", hash = "sha256:4f32145913a2447736dad62495199a8e280a77a0ca662daa2332acf849f0be48"},
+ {file = "pywin32-304-cp310-cp310-win_arm64.whl", hash = "sha256:d3ee45adff48e0551d1aa60d2ec066fec006083b791f5c3527c40cd8aefac71f"},
+ {file = "pywin32-304-cp311-cp311-win32.whl", hash = "sha256:30c53d6ce44c12a316a06c153ea74152d3b1342610f1b99d40ba2795e5af0269"},
+ {file = "pywin32-304-cp311-cp311-win_amd64.whl", hash = "sha256:7ffa0c0fa4ae4077e8b8aa73800540ef8c24530057768c3ac57c609f99a14fd4"},
+ {file = "pywin32-304-cp311-cp311-win_arm64.whl", hash = "sha256:cbbe34dad39bdbaa2889a424d28752f1b4971939b14b1bb48cbf0182a3bcfc43"},
+ {file = "pywin32-304-cp36-cp36m-win32.whl", hash = "sha256:be253e7b14bc601718f014d2832e4c18a5b023cbe72db826da63df76b77507a1"},
+ {file = "pywin32-304-cp36-cp36m-win_amd64.whl", hash = "sha256:de9827c23321dcf43d2f288f09f3b6d772fee11e809015bdae9e69fe13213988"},
+ {file = "pywin32-304-cp37-cp37m-win32.whl", hash = "sha256:f64c0377cf01b61bd5e76c25e1480ca8ab3b73f0c4add50538d332afdf8f69c5"},
+ {file = "pywin32-304-cp37-cp37m-win_amd64.whl", hash = "sha256:bb2ea2aa81e96eee6a6b79d87e1d1648d3f8b87f9a64499e0b92b30d141e76df"},
+ {file = "pywin32-304-cp38-cp38-win32.whl", hash = "sha256:94037b5259701988954931333aafd39cf897e990852115656b014ce72e052e96"},
+ {file = "pywin32-304-cp38-cp38-win_amd64.whl", hash = "sha256:ead865a2e179b30fb717831f73cf4373401fc62fbc3455a0889a7ddac848f83e"},
+ {file = "pywin32-304-cp39-cp39-win32.whl", hash = "sha256:25746d841201fd9f96b648a248f731c1dec851c9a08b8e33da8b56148e4c65cc"},
+ {file = "pywin32-304-cp39-cp39-win_amd64.whl", hash = "sha256:d24a3382f013b21aa24a5cfbfad5a2cd9926610c0affde3e8ab5b3d7dbcf4ac9"},
+]
+pywinpty = [
+ {file = "pywinpty-2.0.10-cp310-none-win_amd64.whl", hash = "sha256:4c7d06ad10f6e92bc850a467f26d98f4f30e73d2fe5926536308c6ae0566bc16"},
+ {file = "pywinpty-2.0.10-cp311-none-win_amd64.whl", hash = "sha256:7ffbd66310b83e42028fc9df7746118978d94fba8c1ebf15a7c1275fdd80b28a"},
+ {file = "pywinpty-2.0.10-cp37-none-win_amd64.whl", hash = "sha256:38cb924f2778b5751ef91a75febd114776b3af0ae411bc667be45dd84fc881d3"},
+ {file = "pywinpty-2.0.10-cp38-none-win_amd64.whl", hash = "sha256:902d79444b29ad1833b8d5c3c9aabdfd428f4f068504430df18074007c8c0de8"},
+ {file = "pywinpty-2.0.10-cp39-none-win_amd64.whl", hash = "sha256:3c46aef80dd50979aff93de199e4a00a8ee033ba7a03cadf0a91fed45f0c39d7"},
+ {file = "pywinpty-2.0.10.tar.gz", hash = "sha256:cdbb5694cf8c7242c2ecfaca35c545d31fa5d5814c3d67a4e628f803f680ebea"},
+]
+pyyaml = [
+ {file = "PyYAML-5.4.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:3b2b1824fe7112845700f815ff6a489360226a5609b96ec2190a45e62a9fc922"},
+ {file = "PyYAML-5.4.1-cp27-cp27m-win32.whl", hash = "sha256:129def1b7c1bf22faffd67b8f3724645203b79d8f4cc81f674654d9902cb4393"},
+ {file = "PyYAML-5.4.1-cp27-cp27m-win_amd64.whl", hash = "sha256:4465124ef1b18d9ace298060f4eccc64b0850899ac4ac53294547536533800c8"},
+ {file = "PyYAML-5.4.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185"},
+ {file = "PyYAML-5.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253"},
+ {file = "PyYAML-5.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc"},
+ {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347"},
+ {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_s390x.whl", hash = "sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541"},
+ {file = "PyYAML-5.4.1-cp36-cp36m-win32.whl", hash = "sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5"},
+ {file = "PyYAML-5.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df"},
+ {file = "PyYAML-5.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018"},
+ {file = "PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63"},
+ {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa"},
+ {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0"},
+ {file = "PyYAML-5.4.1-cp37-cp37m-win32.whl", hash = "sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b"},
+ {file = "PyYAML-5.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf"},
+ {file = "PyYAML-5.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46"},
+ {file = "PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb"},
+ {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247"},
+ {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc"},
+ {file = "PyYAML-5.4.1-cp38-cp38-win32.whl", hash = "sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc"},
+ {file = "PyYAML-5.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696"},
+ {file = "PyYAML-5.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77"},
+ {file = "PyYAML-5.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183"},
+ {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122"},
+ {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6"},
+ {file = "PyYAML-5.4.1-cp39-cp39-win32.whl", hash = "sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10"},
+ {file = "PyYAML-5.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db"},
+ {file = "PyYAML-5.4.1.tar.gz", hash = "sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e"},
+]
+pyzmq = [
+ {file = "pyzmq-25.0.2-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:ac178e666c097c8d3deb5097b58cd1316092fc43e8ef5b5fdb259b51da7e7315"},
+ {file = "pyzmq-25.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:659e62e1cbb063151c52f5b01a38e1df6b54feccfa3e2509d44c35ca6d7962ee"},
+ {file = "pyzmq-25.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8280ada89010735a12b968ec3ea9a468ac2e04fddcc1cede59cb7f5178783b9c"},
+ {file = "pyzmq-25.0.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9b5eeb5278a8a636bb0abdd9ff5076bcbb836cd2302565df53ff1fa7d106d54"},
+ {file = "pyzmq-25.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a2e5fe42dfe6b73ca120b97ac9f34bfa8414feb15e00e37415dbd51cf227ef6"},
+ {file = "pyzmq-25.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:827bf60e749e78acb408a6c5af6688efbc9993e44ecc792b036ec2f4b4acf485"},
+ {file = "pyzmq-25.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7b504ae43d37e282301da586529e2ded8b36d4ee2cd5e6db4386724ddeaa6bbc"},
+ {file = "pyzmq-25.0.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb1f69a0a2a2b1aae8412979dd6293cc6bcddd4439bf07e4758d864ddb112354"},
+ {file = "pyzmq-25.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2b9c9cc965cdf28381e36da525dcb89fc1571d9c54800fdcd73e3f73a2fc29bd"},
+ {file = "pyzmq-25.0.2-cp310-cp310-win32.whl", hash = "sha256:24abbfdbb75ac5039205e72d6c75f10fc39d925f2df8ff21ebc74179488ebfca"},
+ {file = "pyzmq-25.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6a821a506822fac55d2df2085a52530f68ab15ceed12d63539adc32bd4410f6e"},
+ {file = "pyzmq-25.0.2-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:9af0bb0277e92f41af35e991c242c9c71920169d6aa53ade7e444f338f4c8128"},
+ {file = "pyzmq-25.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:54a96cf77684a3a537b76acfa7237b1e79a8f8d14e7f00e0171a94b346c5293e"},
+ {file = "pyzmq-25.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88649b19ede1cab03b96b66c364cbbf17c953615cdbc844f7f6e5f14c5e5261c"},
+ {file = "pyzmq-25.0.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:715cff7644a80a7795953c11b067a75f16eb9fc695a5a53316891ebee7f3c9d5"},
+ {file = "pyzmq-25.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:312b3f0f066b4f1d17383aae509bacf833ccaf591184a1f3c7a1661c085063ae"},
+ {file = "pyzmq-25.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d488c5c8630f7e782e800869f82744c3aca4aca62c63232e5d8c490d3d66956a"},
+ {file = "pyzmq-25.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:38d9f78d69bcdeec0c11e0feb3bc70f36f9b8c44fc06e5d06d91dc0a21b453c7"},
+ {file = "pyzmq-25.0.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3059a6a534c910e1d5d068df42f60d434f79e6cc6285aa469b384fa921f78cf8"},
+ {file = "pyzmq-25.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6526d097b75192f228c09d48420854d53dfbc7abbb41b0e26f363ccb26fbc177"},
+ {file = "pyzmq-25.0.2-cp311-cp311-win32.whl", hash = "sha256:5c5fbb229e40a89a2fe73d0c1181916f31e30f253cb2d6d91bea7927c2e18413"},
+ {file = "pyzmq-25.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:ed15e3a2c3c2398e6ae5ce86d6a31b452dfd6ad4cd5d312596b30929c4b6e182"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:032f5c8483c85bf9c9ca0593a11c7c749d734ce68d435e38c3f72e759b98b3c9"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:374b55516393bfd4d7a7daa6c3b36d6dd6a31ff9d2adad0838cd6a203125e714"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:08bfcc21b5997a9be4fefa405341320d8e7f19b4d684fb9c0580255c5bd6d695"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1a843d26a8da1b752c74bc019c7b20e6791ee813cd6877449e6a1415589d22ff"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:b48616a09d7df9dbae2f45a0256eee7b794b903ddc6d8657a9948669b345f220"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d4427b4a136e3b7f85516c76dd2e0756c22eec4026afb76ca1397152b0ca8145"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:26b0358e8933990502f4513c991c9935b6c06af01787a36d133b7c39b1df37fa"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-win32.whl", hash = "sha256:c8fedc3ccd62c6b77dfe6f43802057a803a411ee96f14e946f4a76ec4ed0e117"},
+ {file = "pyzmq-25.0.2-cp36-cp36m-win_amd64.whl", hash = "sha256:2da6813b7995b6b1d1307329c73d3e3be2fd2d78e19acfc4eff2e27262732388"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a35960c8b2f63e4ef67fd6731851030df68e4b617a6715dd11b4b10312d19fef"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eef2a0b880ab40aca5a878933376cb6c1ec483fba72f7f34e015c0f675c90b20"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:85762712b74c7bd18e340c3639d1bf2f23735a998d63f46bb6584d904b5e401d"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:64812f29d6eee565e129ca14b0c785744bfff679a4727137484101b34602d1a7"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:510d8e55b3a7cd13f8d3e9121edf0a8730b87d925d25298bace29a7e7bc82810"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b164cc3c8acb3d102e311f2eb6f3c305865ecb377e56adc015cb51f721f1dda6"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:28fdb9224a258134784a9cf009b59265a9dde79582fb750d4e88a6bcbc6fa3dc"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-win32.whl", hash = "sha256:dd771a440effa1c36d3523bc6ba4e54ff5d2e54b4adcc1e060d8f3ca3721d228"},
+ {file = "pyzmq-25.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:9bdc40efb679b9dcc39c06d25629e55581e4c4f7870a5e88db4f1c51ce25e20d"},
+ {file = "pyzmq-25.0.2-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:1f82906a2d8e4ee310f30487b165e7cc8ed09c009e4502da67178b03083c4ce0"},
+ {file = "pyzmq-25.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:21ec0bf4831988af43c8d66ba3ccd81af2c5e793e1bf6790eb2d50e27b3c570a"},
+ {file = "pyzmq-25.0.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:abbce982a17c88d2312ec2cf7673985d444f1beaac6e8189424e0a0e0448dbb3"},
+ {file = "pyzmq-25.0.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9e1d2f2d86fc75ed7f8845a992c5f6f1ab5db99747fb0d78b5e4046d041164d2"},
+ {file = "pyzmq-25.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2e92ff20ad5d13266bc999a29ed29a3b5b101c21fdf4b2cf420c09db9fb690e"},
+ {file = "pyzmq-25.0.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edbbf06cc2719889470a8d2bf5072bb00f423e12de0eb9ffec946c2c9748e149"},
+ {file = "pyzmq-25.0.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:77942243ff4d14d90c11b2afd8ee6c039b45a0be4e53fb6fa7f5e4fd0b59da39"},
+ {file = "pyzmq-25.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ab046e9cb902d1f62c9cc0eca055b1d11108bdc271caf7c2171487298f229b56"},
+ {file = "pyzmq-25.0.2-cp38-cp38-win32.whl", hash = "sha256:ad761cfbe477236802a7ab2c080d268c95e784fe30cafa7e055aacd1ca877eb0"},
+ {file = "pyzmq-25.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:8560756318ec7c4c49d2c341012167e704b5a46d9034905853c3d1ade4f55bee"},
+ {file = "pyzmq-25.0.2-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:ab2c056ac503f25a63f6c8c6771373e2a711b98b304614151dfb552d3d6c81f6"},
+ {file = "pyzmq-25.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cca8524b61c0eaaa3505382dc9b9a3bc8165f1d6c010fdd1452c224225a26689"},
+ {file = "pyzmq-25.0.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cfb9f7eae02d3ac42fbedad30006b7407c984a0eb4189a1322241a20944d61e5"},
+ {file = "pyzmq-25.0.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5eaeae038c68748082137d6896d5c4db7927e9349237ded08ee1bbd94f7361c9"},
+ {file = "pyzmq-25.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a31992a8f8d51663ebf79df0df6a04ffb905063083d682d4380ab8d2c67257c"},
+ {file = "pyzmq-25.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6a979e59d2184a0c8f2ede4b0810cbdd86b64d99d9cc8a023929e40dce7c86cc"},
+ {file = "pyzmq-25.0.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:1f124cb73f1aa6654d31b183810febc8505fd0c597afa127c4f40076be4574e0"},
+ {file = "pyzmq-25.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:65c19a63b4a83ae45d62178b70223adeee5f12f3032726b897431b6553aa25af"},
+ {file = "pyzmq-25.0.2-cp39-cp39-win32.whl", hash = "sha256:83d822e8687621bed87404afc1c03d83fa2ce39733d54c2fd52d8829edb8a7ff"},
+ {file = "pyzmq-25.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:24683285cc6b7bf18ad37d75b9db0e0fefe58404e7001f1d82bf9e721806daa7"},
+ {file = "pyzmq-25.0.2-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4a4b4261eb8f9ed71f63b9eb0198dd7c934aa3b3972dac586d0ef502ba9ab08b"},
+ {file = "pyzmq-25.0.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:62ec8d979f56c0053a92b2b6a10ff54b9ec8a4f187db2b6ec31ee3dd6d3ca6e2"},
+ {file = "pyzmq-25.0.2-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:affec1470351178e892121b3414c8ef7803269f207bf9bef85f9a6dd11cde264"},
+ {file = "pyzmq-25.0.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffc71111433bd6ec8607a37b9211f4ef42e3d3b271c6d76c813669834764b248"},
+ {file = "pyzmq-25.0.2-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:6fadc60970714d86eff27821f8fb01f8328dd36bebd496b0564a500fe4a9e354"},
+ {file = "pyzmq-25.0.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:269968f2a76c0513490aeb3ba0dc3c77b7c7a11daa894f9d1da88d4a0db09835"},
+ {file = "pyzmq-25.0.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f7c8b8368e84381ae7c57f1f5283b029c888504aaf4949c32e6e6fb256ec9bf0"},
+ {file = "pyzmq-25.0.2-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:25e6873a70ad5aa31e4a7c41e5e8c709296edef4a92313e1cd5fc87bbd1874e2"},
+ {file = "pyzmq-25.0.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b733076ff46e7db5504c5e7284f04a9852c63214c74688bdb6135808531755a3"},
+ {file = "pyzmq-25.0.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:a6f6ae12478fdc26a6d5fdb21f806b08fa5403cd02fd312e4cb5f72df078f96f"},
+ {file = "pyzmq-25.0.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:67da1c213fbd208906ab3470cfff1ee0048838365135a9bddc7b40b11e6d6c89"},
+ {file = "pyzmq-25.0.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:531e36d9fcd66f18de27434a25b51d137eb546931033f392e85674c7a7cea853"},
+ {file = "pyzmq-25.0.2-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34a6fddd159ff38aa9497b2e342a559f142ab365576284bc8f77cb3ead1f79c5"},
+ {file = "pyzmq-25.0.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b491998ef886662c1f3d49ea2198055a9a536ddf7430b051b21054f2a5831800"},
+ {file = "pyzmq-25.0.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:5d496815074e3e3d183fe2c7fcea2109ad67b74084c254481f87b64e04e9a471"},
+ {file = "pyzmq-25.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:56a94ab1d12af982b55ca96c6853db6ac85505e820d9458ac76364c1998972f4"},
+ {file = "pyzmq-25.0.2.tar.gz", hash = "sha256:6b8c1bbb70e868dc88801aa532cae6bd4e3b5233784692b786f17ad2962e5149"},
+]
+qtconsole = [
+ {file = "qtconsole-5.4.2-py3-none-any.whl", hash = "sha256:30975c6a7d7941dd646d00a23e5982db49beaa60c3920bb243727d43da805f12"},
+ {file = "qtconsole-5.4.2.tar.gz", hash = "sha256:dc935780da276a2ab31a7a08a8cf327a2ea47fa27e21d485073251a7eeb16167"},
+]
+QtPy = [
+ {file = "QtPy-2.3.1-py3-none-any.whl", hash = "sha256:5193d20e0b16e4d9d3bc2c642d04d9f4e2c892590bd1b9c92bfe38a95d5a2e12"},
+ {file = "QtPy-2.3.1.tar.gz", hash = "sha256:a8c74982d6d172ce124d80cafd39653df78989683f760f2281ba91a6e7b9de8b"},
+]
+querystring-parser = [
+ {file = "querystring_parser-1.2.4-py2.py3-none-any.whl", hash = "sha256:d2fa90765eaf0de96c8b087872991a10238e89ba015ae59fedfed6bd61c242a0"},
+ {file = "querystring_parser-1.2.4.tar.gz", hash = "sha256:644fce1cffe0530453b43a83a38094dbe422ccba8c9b2f2a1c00280e14ca8a62"},
+]
+regex = [
+ {file = "regex-2022.7.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d35bbcbf70d14f724e7489746cf68efe122796578addd98f91428e144d0ad266"},
+ {file = "regex-2022.7.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:12e1404dfb4e928d3273a10e3468877fe84bdcd3c50b655a2c9613cfc5d9fe63"},
+ {file = "regex-2022.7.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:727edff0a4eaff3b6d26cbb50216feac9055aba7e6290eec23c061c2fe2fab55"},
+ {file = "regex-2022.7.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:18e6203cfd81df42a987175aaeed7ba46bcb42130cd81763e2d5edcff0006d5d"},
+ {file = "regex-2022.7.9-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:06c509bd7dcb7966bdb03974457d548e54d8327bad5b0c917e87248edc43e2eb"},
+ {file = "regex-2022.7.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f94b0befc811fe74a972b1739fffbf74c0dc1a91102aca8e324aa4f2c6991bd"},
+ {file = "regex-2022.7.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:121981ba84309dabefd5e1debd49be6d51624e54b4d44bfc184cd8d555ff1df1"},
+ {file = "regex-2022.7.9-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d40b4447784dbe0896a6d10a178f6724598161f942c56f5a60dc0ef7fe63f7a1"},
+ {file = "regex-2022.7.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:414ae507ba88264444baf771fec43ce0adcd4c5dbb304d3e0716f3f4d4499d2e"},
+ {file = "regex-2022.7.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9e4006942334fa954ebd32fa0728718ec870f95f4ba7cda9edc46dd49c294f22"},
+ {file = "regex-2022.7.9-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0c1821146b429e6fdbd13ea10f26765e48d5284bc79749468cfbfe3ceb929f0d"},
+ {file = "regex-2022.7.9-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:402fa998c5988d11ed34585eb65740dcebd0fd11844d12eb0a6b4be178eb9c64"},
+ {file = "regex-2022.7.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d2672d68cf6c8452b6758fc3cd2d8feac966d511eed79a68182a5297b473af9c"},
+ {file = "regex-2022.7.9-cp310-cp310-win32.whl", hash = "sha256:2e5db20412f0db8798ff72473d16da5f13ec808e975b49188badb2462f529fa9"},
+ {file = "regex-2022.7.9-cp310-cp310-win_amd64.whl", hash = "sha256:667a06bb8d72b6da3d9cf38dac4ba969688868ed2279a692e993d2c0e1c30aba"},
+ {file = "regex-2022.7.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b72a4ec79a15f6066d14ae1c472b743af4b4ecee14420e8d6e4a336b49b8f21c"},
+ {file = "regex-2022.7.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea27acd97a752cfefa9907da935e583efecb302e6e9866f37565968c8407ad58"},
+ {file = "regex-2022.7.9-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:42da079e31ae9818ffa7a35cdd16ab7104e3f7eca9c0958040aede827b2e55c6"},
+ {file = "regex-2022.7.9-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ab1cb36b411f16da6e057ef8e6657dd0af36f59a667f07e0b4b617e44e53d7b2"},
+ {file = "regex-2022.7.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7f5ccfff648093152cadf6d886c7bd922047532f72024c953a79c7553aac2fe"},
+ {file = "regex-2022.7.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9daeccb2764bf4cc280c40c6411ae176bb0876948e536590a052b3d647254c95"},
+ {file = "regex-2022.7.9-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:673549a0136c7893f567ed71ab5225ed3701c79b17c0a7faee846c645fc24010"},
+ {file = "regex-2022.7.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0fd8c3635fa03ef79d07c7b3ed693b3f3930ccb52c0c51761c3296a7525b135c"},
+ {file = "regex-2022.7.9-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:a048f91823862270905cb22ef88038b08aac852ce48e0ecc4b4bf1b895ec37d9"},
+ {file = "regex-2022.7.9-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:f8a2fd2f62a77536e4e3193303bec380df40d99e253b1c8f9b6eafa07eaeff67"},
+ {file = "regex-2022.7.9-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:ab0709daedc1099bbd4371ae17eeedd4efc1cf70fcdcfe5de1374a0944b61f80"},
+ {file = "regex-2022.7.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:ee769a438827e443ed428e66d0aa7131c653ecd86ddc5d4644a81ed1d93af0e7"},
+ {file = "regex-2022.7.9-cp36-cp36m-win32.whl", hash = "sha256:e1fdda3ec7e9785065b67941693995cab95b54023a21db9bf39e54cc7b2c3526"},
+ {file = "regex-2022.7.9-cp36-cp36m-win_amd64.whl", hash = "sha256:00d2e907d3c5e4f85197c8d2263a9cc2d34bf234a9c6236ae42a3fb0bc09b759"},
+ {file = "regex-2022.7.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5b1cffff2d9f832288fe516021cb81c95c57c0067b13a82f1d2daabdbc2f4270"},
+ {file = "regex-2022.7.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cfeb71095c8d8380a5df5a38ff94d27a3f483717e509130a822b4d6400b7991"},
+ {file = "regex-2022.7.9-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0186edcda692c38381db8ac257c2d023fd2e08818d45dc5bee4ed84212045f51"},
+ {file = "regex-2022.7.9-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc74f0171eede67d79a79c06eca0fe5b7b280dbb8c27ad1fae4ced2ad66268f"},
+ {file = "regex-2022.7.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbdf4fc6adf38fab1091c579ece3fe9f493bd0f1cfc3d2c76d2e52461ca4f8a9"},
+ {file = "regex-2022.7.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a00cd58a30a1041c193777cb1bc090200b05ff4b073d5935738afd1023e63069"},
+ {file = "regex-2022.7.9-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d561dcb0fb0ab858291837d51330696a45fd3ba6912a332a4ee130e5484b9e47"},
+ {file = "regex-2022.7.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ae1c5b435d44aa91d48cc710f20c3485e0584a3ad3565d5ae031d61a35f674f4"},
+ {file = "regex-2022.7.9-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1703490c5b850fa9cef1af00c58966756042e6ca22f4fb5bb857345cd535834f"},
+ {file = "regex-2022.7.9-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:782627a1cb8fbb1c78d8e841f5b71c2c683086c038f975bebdac7cce7678a96f"},
+ {file = "regex-2022.7.9-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:b279b9bb401af41130fd2a09427105100bc8c624ed45da1c81c1c0d0aa639734"},
+ {file = "regex-2022.7.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:13d74951c14708f00700bb29475129ecbc40e01b4029c62ee7bfe9d1f59f31ce"},
+ {file = "regex-2022.7.9-cp37-cp37m-win32.whl", hash = "sha256:1244e9b9b4b81c9c34e8a84273ffaeebdc78abc98a5b02dcdd49845eb3c63bd7"},
+ {file = "regex-2022.7.9-cp37-cp37m-win_amd64.whl", hash = "sha256:67bd3bdd27db7a6460384869dd4b9c54267d805b67d70b20495bb5767f8e051c"},
+ {file = "regex-2022.7.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:34ae4f35db30caa4caf85c55069fcb7a05966a3a5ba6e9e1dab5477d84fbb08a"},
+ {file = "regex-2022.7.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f355caec5bbce20421dc26e53787b10e32fd0df68db2b795435217210c08d69c"},
+ {file = "regex-2022.7.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d93167b7d7731fa9ff9fdc1bae84ec9c7133b01a35f8cc04e926d48da6ce1f7"},
+ {file = "regex-2022.7.9-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ab39aa445d00902c43a1e951871bedc7f18d095a21eccba153d594faac34aea"},
+ {file = "regex-2022.7.9-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1948d3ceac5b2d55bc93159c1e0679a256a87a54c735be5cef4543a9e692dbb9"},
+ {file = "regex-2022.7.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0220a7a16fd4bfc700661f920510defd31ef7830ce992d5cc51777aa8ccd724"},
+ {file = "regex-2022.7.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c6df8be7d1dd35a0d9a200fbc29f888c4452c8882d284f87608046152e049e6"},
+ {file = "regex-2022.7.9-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c7c5f914b0eb5242c09f91058b80295525897e873b522575ab235b48db125597"},
+ {file = "regex-2022.7.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c2cd93725911c0159d597b90c96151070ef7e0e67604637e2f2abe06c34bf079"},
+ {file = "regex-2022.7.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:119091c675e6ad19da8770f89aa1d52f4ad2a2018d631956f3e90c45882df880"},
+ {file = "regex-2022.7.9-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4c5913cb9769038bd03e42318955c2f15a688384a6a0b807bcfc8271603d9277"},
+ {file = "regex-2022.7.9-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a3c47c71fde0c5d584402e67546c81af9951540f1f622d821e9c20761556473a"},
+ {file = "regex-2022.7.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:40a28759d345c0bb1f5b0ac74ac04f5d48136019522c95c0ec4b07786f67ce20"},
+ {file = "regex-2022.7.9-cp38-cp38-win32.whl", hash = "sha256:8e2075ed4ea2e231e2e98b16cfa5dae87e9a6045a71104525e1efc29aa8faa8e"},
+ {file = "regex-2022.7.9-cp38-cp38-win_amd64.whl", hash = "sha256:9f1c8fffd4def0b76c0947b8cb261b266e31041785dc2dc2db7569407a2f54fe"},
+ {file = "regex-2022.7.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:473a7d21932ce7c314953b33c32e63df690181860edcdf14bba1278cdf71b07f"},
+ {file = "regex-2022.7.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:192c2784833aea6fc7b004730bf1b91b8b8c6b998b30271aaf3bd8adfef20a96"},
+ {file = "regex-2022.7.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dce6b2ad817e3eb107f8704782b091b0631dd3adf47f14bdc086165d05b528b0"},
+ {file = "regex-2022.7.9-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2fc1e3928c1189c0382c547c17717c6d9f425fffe619ef94270fe4c6c8be0a6"},
+ {file = "regex-2022.7.9-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f32e0d1c7e7b0b9c3cac76f3d278e7ee6b99c95672d2c1c6ea625033431837c0"},
+ {file = "regex-2022.7.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f87e9108bb532f8a1fc6bf7e69b930a35c7b0267b8fef0a3ede0bcb4c5aaa531"},
+ {file = "regex-2022.7.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:303676797c4c7978726e74eb8255d68f7125a3a29da71ff453448f2117290e9a"},
+ {file = "regex-2022.7.9-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a6d9ea727fd1233ee746bf44dd37e7d4320b3ed8ff09e73d7638c969b28d280f"},
+ {file = "regex-2022.7.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7d462ba84655abeddae4dfc517fe1afefb5430b3b5acb0a954de12a47aea7183"},
+ {file = "regex-2022.7.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:e2a262ec85c595fc8e1f3162cafc654d2219125c00ea3a190c173cea70d2cc7a"},
+ {file = "regex-2022.7.9-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:162a5939a6fdf48658d3565eeff35acdd207e07367bf5caaff3d9ea7cb77d7a9"},
+ {file = "regex-2022.7.9-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:d07d849c9e2eca80adb85d3567302a47195a603ad7b1f0a07508e253c041f954"},
+ {file = "regex-2022.7.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:206a327e628bc529d64b21ff79a5e2564f5aec7dc7abcd4b2e8a4b271ec10550"},
+ {file = "regex-2022.7.9-cp39-cp39-win32.whl", hash = "sha256:49fcb45931a693b0e901972c5e077ea2cf30ec39da699645c43cb8b1542c6e14"},
+ {file = "regex-2022.7.9-cp39-cp39-win_amd64.whl", hash = "sha256:0a3f3f45c5902eb4d90266002ccb035531ae9b9278f6d5e8028247c34d192099"},
+ {file = "regex-2022.7.9.tar.gz", hash = "sha256:601c99ac775b6c89699a48976f3dbb000b47d3ca59362c8abc9582e6d0780d91"},
+]
+requests = [
+ {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"},
+ {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"},
+]
+requests-oauthlib = [
+ {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
+ {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
+]
+rsa = [
+ {file = "rsa-4.7.2-py3-none-any.whl", hash = "sha256:78f9a9bf4e7be0c5ded4583326e7461e3a3c5aae24073648b4bdfa797d78c9d2"},
+ {file = "rsa-4.7.2.tar.gz", hash = "sha256:9d689e6ca1b3038bc82bf8d23e944b6b6037bc02301a574935b2dd946e0353b9"},
+]
+s3transfer = [
+ {file = "s3transfer-0.6.0-py3-none-any.whl", hash = "sha256:06176b74f3a15f61f1b4f25a1fc29a4429040b7647133a463da8fa5bd28d5ecd"},
+ {file = "s3transfer-0.6.0.tar.gz", hash = "sha256:2ed07d3866f523cc561bf4a00fc5535827981b117dd7876f036b0c1aca42c947"},
+]
+scikit-learn = [
+ {file = "scikit-learn-1.0.2.tar.gz", hash = "sha256:b5870959a5484b614f26d31ca4c17524b1b0317522199dc985c3b4256e030767"},
+ {file = "scikit_learn-1.0.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:da3c84694ff693b5b3194d8752ccf935a665b8b5edc33a283122f4273ca3e687"},
+ {file = "scikit_learn-1.0.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:75307d9ea39236cad7eea87143155eea24d48f93f3a2f9389c817f7019f00705"},
+ {file = "scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f14517e174bd7332f1cca2c959e704696a5e0ba246eb8763e6c24876d8710049"},
+ {file = "scikit_learn-1.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9aac97e57c196206179f674f09bc6bffcd0284e2ba95b7fe0b402ac3f986023"},
+ {file = "scikit_learn-1.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:d93d4c28370aea8a7cbf6015e8a669cd5d69f856cc2aa44e7a590fb805bb5583"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:85260fb430b795d806251dd3bb05e6f48cdc777ac31f2bcf2bc8bbed3270a8f5"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a053a6a527c87c5c4fa7bf1ab2556fa16d8345cf99b6c5a19030a4a7cd8fd2c0"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:245c9b5a67445f6f044411e16a93a554edc1efdcce94d3fc0bc6a4b9ac30b752"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158faf30684c92a78e12da19c73feff9641a928a8024b4fa5ec11d583f3d8a87"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:08ef968f6b72033c16c479c966bf37ccd49b06ea91b765e1cc27afefe723920b"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16455ace947d8d9e5391435c2977178d0ff03a261571e67f627c8fee0f9d431a"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-win32.whl", hash = "sha256:2f3b453e0b149898577e301d27e098dfe1a36943f7bb0ad704d1e548efc3b448"},
+ {file = "scikit_learn-1.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:46f431ec59dead665e1370314dbebc99ead05e1c0a9df42f22d6a0e00044820f"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:ff3fa8ea0e09e38677762afc6e14cad77b5e125b0ea70c9bba1992f02c93b028"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:9369b030e155f8188743eb4893ac17a27f81d28a884af460870c7c072f114243"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7d6b2475f1c23a698b48515217eb26b45a6598c7b1840ba23b3c5acece658dbb"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:285db0352e635b9e3392b0b426bc48c3b485512d3b4ac3c7a44ec2a2ba061e66"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cb33fe1dc6f73dc19e67b264dbb5dde2a0539b986435fdd78ed978c14654830"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1391d1a6e2268485a63c3073111fe3ba6ec5145fc957481cfd0652be571226d"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc3744dabc56b50bec73624aeca02e0def06b03cb287de26836e730659c5d29c"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-win32.whl", hash = "sha256:a999c9f02ff9570c783069f1074f06fe7386ec65b84c983db5aeb8144356a355"},
+ {file = "scikit_learn-1.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:7626a34eabbf370a638f32d1a3ad50526844ba58d63e3ab81ba91e2a7c6d037e"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:a90b60048f9ffdd962d2ad2fb16367a87ac34d76e02550968719eb7b5716fd10"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7a93c1292799620df90348800d5ac06f3794c1316ca247525fa31169f6d25855"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:eabceab574f471de0b0eb3f2ecf2eee9f10b3106570481d007ed1c84ebf6d6a1"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:55f2f3a8414e14fbee03782f9fe16cca0f141d639d2b1c1a36779fa069e1db57"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80095a1e4b93bd33261ef03b9bc86d6db649f988ea4dbcf7110d0cded8d7213d"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fa38a1b9b38ae1fad2863eff5e0d69608567453fdfc850c992e6e47eb764e846"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff746a69ff2ef25f62b36338c615dd15954ddc3ab8e73530237dd73235e76d62"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-win32.whl", hash = "sha256:e174242caecb11e4abf169342641778f68e1bfaba80cd18acd6bc84286b9a534"},
+ {file = "scikit_learn-1.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:b54a62c6e318ddbfa7d22c383466d38d2ee770ebdb5ddb668d56a099f6eaf75f"},
+]
+scipy = [
+ {file = "scipy-1.7.3-1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:c9e04d7e9b03a8a6ac2045f7c5ef741be86727d8f49c45db45f244bdd2bcff17"},
+ {file = "scipy-1.7.3-1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b0e0aeb061a1d7dcd2ed59ea57ee56c9b23dd60100825f98238c06ee5cc4467e"},
+ {file = "scipy-1.7.3-1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b78a35c5c74d336f42f44106174b9851c783184a85a3fe3e68857259b37b9ffb"},
+ {file = "scipy-1.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:173308efba2270dcd61cd45a30dfded6ec0085b4b6eb33b5eb11ab443005e088"},
+ {file = "scipy-1.7.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:21b66200cf44b1c3e86495e3a436fc7a26608f92b8d43d344457c54f1c024cbc"},
+ {file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceebc3c4f6a109777c0053dfa0282fddb8893eddfb0d598574acfb734a926168"},
+ {file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7eaea089345a35130bc9a39b89ec1ff69c208efa97b3f8b25ea5d4c41d88094"},
+ {file = "scipy-1.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:304dfaa7146cffdb75fbf6bb7c190fd7688795389ad060b970269c8576d038e9"},
+ {file = "scipy-1.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:033ce76ed4e9f62923e1f8124f7e2b0800db533828c853b402c7eec6e9465d80"},
+ {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4d242d13206ca4302d83d8a6388c9dfce49fc48fdd3c20efad89ba12f785bf9e"},
+ {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8499d9dd1459dc0d0fe68db0832c3d5fc1361ae8e13d05e6849b358dc3f2c279"},
+ {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca36e7d9430f7481fc7d11e015ae16fbd5575615a8e9060538104778be84addf"},
+ {file = "scipy-1.7.3-cp37-cp37m-win32.whl", hash = "sha256:e2c036492e673aad1b7b0d0ccdc0cb30a968353d2c4bf92ac8e73509e1bf212c"},
+ {file = "scipy-1.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:866ada14a95b083dd727a845a764cf95dd13ba3dc69a16b99038001b05439709"},
+ {file = "scipy-1.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:65bd52bf55f9a1071398557394203d881384d27b9c2cad7df9a027170aeaef93"},
+ {file = "scipy-1.7.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f99d206db1f1ae735a8192ab93bd6028f3a42f6fa08467d37a14eb96c9dd34a3"},
+ {file = "scipy-1.7.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5f2cfc359379c56b3a41b17ebd024109b2049f878badc1e454f31418c3a18436"},
+ {file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb7ae2c4dbdb3c9247e07acc532f91077ae6dbc40ad5bd5dca0bb5a176ee9bda"},
+ {file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c2d250074cfa76715d58830579c64dff7354484b284c2b8b87e5a38321672c"},
+ {file = "scipy-1.7.3-cp38-cp38-win32.whl", hash = "sha256:87069cf875f0262a6e3187ab0f419f5b4280d3dcf4811ef9613c605f6e4dca95"},
+ {file = "scipy-1.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:7edd9a311299a61e9919ea4192dd477395b50c014cdc1a1ac572d7c27e2207fa"},
+ {file = "scipy-1.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eef93a446114ac0193a7b714ce67659db80caf940f3232bad63f4c7a81bc18df"},
+ {file = "scipy-1.7.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb326658f9b73c07081300daba90a8746543b5ea177184daed26528273157294"},
+ {file = "scipy-1.7.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93378f3d14fff07572392ce6a6a2ceb3a1f237733bd6dcb9eb6a2b29b0d19085"},
+ {file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edad1cf5b2ce1912c4d8ddad20e11d333165552aba262c882e28c78bbc09dbf6"},
+ {file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d1cc2c19afe3b5a546ede7e6a44ce1ff52e443d12b231823268019f608b9b12"},
+ {file = "scipy-1.7.3-cp39-cp39-win32.whl", hash = "sha256:2c56b820d304dffcadbbb6cbfbc2e2c79ee46ea291db17e288e73cd3c64fefa9"},
+ {file = "scipy-1.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:3f78181a153fa21c018d346f595edd648344751d7f03ab94b398be2ad083ed3e"},
+ {file = "scipy-1.7.3.tar.gz", hash = "sha256:ab5875facfdef77e0a47d5fd39ea178b58e60e454a4c85aa1e52fcb80db7babf"},
+]
+Send2Trash = [
+ {file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
+ {file = "Send2Trash-1.8.0.tar.gz", hash = "sha256:d2c24762fd3759860a0aff155e45871447ea58d2be6bdd39b5c8f966a0c99c2d"},
+]
+setuptools = [
+ {file = "setuptools-65.4.0-py3-none-any.whl", hash = "sha256:c2d2709550f15aab6c9110196ea312f468f41cd546bceb24127a1be6fdcaeeb1"},
+ {file = "setuptools-65.4.0.tar.gz", hash = "sha256:a8f6e213b4b0661f590ccf40de95d28a177cd747d098624ad3f69c40287297e9"},
+]
+setuptools-scm = [
+ {file = "setuptools_scm-7.0.4-py3-none-any.whl", hash = "sha256:53a6f51451a84d891ca485cec700a802413bbc5e76ee65da134e54c733a6e44d"},
+ {file = "setuptools_scm-7.0.4.tar.gz", hash = "sha256:c27bc1f48593cfc9527251f1f0fc41ce282ea57bbc7fd5a1ea3acb99325fab4c"},
+]
+six = [
+ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
+ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
+]
+smmap = [
+ {file = "smmap-5.0.0-py3-none-any.whl", hash = "sha256:2aba19d6a040e78d8b09de5c57e96207b09ed71d8e55ce0959eeee6c8e190d94"},
+ {file = "smmap-5.0.0.tar.gz", hash = "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"},
+]
+sniffio = [
+ {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
+ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
+]
+snowballstemmer = [
+ {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"},
+ {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},
+]
+soupsieve = [
+ {file = "soupsieve-2.4.1-py3-none-any.whl", hash = "sha256:1c1bfee6819544a3447586c889157365a27e10d88cde3ad3da0cf0ddf646feb8"},
+ {file = "soupsieve-2.4.1.tar.gz", hash = "sha256:89d12b2d5dfcd2c9e8c22326da9d9aa9cb3dfab0a83a024f05704076ee8d35ea"},
+]
+SQLAlchemy = [
+ {file = "SQLAlchemy-1.4.42-cp27-cp27m-macosx_10_14_x86_64.whl", hash = "sha256:28e881266a172a4d3c5929182fde6bb6fba22ac93f137d5380cc78a11a9dd124"},
+ {file = "SQLAlchemy-1.4.42-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ca9389a00f639383c93ed00333ed763812f80b5ae9e772ea32f627043f8c9c88"},
+ {file = "SQLAlchemy-1.4.42-cp27-cp27m-win32.whl", hash = "sha256:1d0c23ecf7b3bc81e29459c34a3f4c68ca538de01254e24718a7926810dc39a6"},
+ {file = "SQLAlchemy-1.4.42-cp27-cp27m-win_amd64.whl", hash = "sha256:6c9d004eb78c71dd4d3ce625b80c96a827d2e67af9c0d32b1c1e75992a7916cc"},
+ {file = "SQLAlchemy-1.4.42-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:9e3a65ce9ed250b2f096f7b559fe3ee92e6605fab3099b661f0397a9ac7c8d95"},
+ {file = "SQLAlchemy-1.4.42-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:2e56dfed0cc3e57b2f5c35719d64f4682ef26836b81067ee6cfad062290fd9e2"},
+ {file = "SQLAlchemy-1.4.42-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b42c59ffd2d625b28cdb2ae4cde8488543d428cba17ff672a543062f7caee525"},
+ {file = "SQLAlchemy-1.4.42-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:22459fc1718785d8a86171bbe7f01b5c9d7297301ac150f508d06e62a2b4e8d2"},
+ {file = "SQLAlchemy-1.4.42-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df76e9c60879fdc785a34a82bf1e8691716ffac32e7790d31a98d7dec6e81545"},
+ {file = "SQLAlchemy-1.4.42-cp310-cp310-win32.whl", hash = "sha256:e7e740453f0149437c101ea4fdc7eea2689938c5760d7dcc436c863a12f1f565"},
+ {file = "SQLAlchemy-1.4.42-cp310-cp310-win_amd64.whl", hash = "sha256:effc89e606165ca55f04f3f24b86d3e1c605e534bf1a96e4e077ce1b027d0b71"},
+ {file = "SQLAlchemy-1.4.42-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:97ff50cd85bb907c2a14afb50157d0d5486a4b4639976b4a3346f34b6d1b5272"},
+ {file = "SQLAlchemy-1.4.42-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e12c6949bae10f1012ab5c0ea52ab8db99adcb8c7b717938252137cdf694c775"},
+ {file = "SQLAlchemy-1.4.42-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11b2ec26c5d2eefbc3e6dca4ec3d3d95028be62320b96d687b6e740424f83b7d"},
+ {file = "SQLAlchemy-1.4.42-cp311-cp311-win32.whl", hash = "sha256:6045b3089195bc008aee5c273ec3ba9a93f6a55bc1b288841bd4cfac729b6516"},
+ {file = "SQLAlchemy-1.4.42-cp311-cp311-win_amd64.whl", hash = "sha256:0501f74dd2745ec38f44c3a3900fb38b9db1ce21586b691482a19134062bf049"},
+ {file = "SQLAlchemy-1.4.42-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:6e39e97102f8e26c6c8550cb368c724028c575ec8bc71afbbf8faaffe2b2092a"},
+ {file = "SQLAlchemy-1.4.42-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15d878929c30e41fb3d757a5853b680a561974a0168cd33a750be4ab93181628"},
+ {file = "SQLAlchemy-1.4.42-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fa5b7eb2051e857bf83bade0641628efe5a88de189390725d3e6033a1fff4257"},
+ {file = "SQLAlchemy-1.4.42-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e1c5f8182b4f89628d782a183d44db51b5af84abd6ce17ebb9804355c88a7b5"},
+ {file = "SQLAlchemy-1.4.42-cp36-cp36m-win32.whl", hash = "sha256:a7dd5b7b34a8ba8d181402d824b87c5cee8963cb2e23aa03dbfe8b1f1e417cde"},
+ {file = "SQLAlchemy-1.4.42-cp36-cp36m-win_amd64.whl", hash = "sha256:5ede1495174e69e273fad68ad45b6d25c135c1ce67723e40f6cf536cb515e20b"},
+ {file = "SQLAlchemy-1.4.42-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:9256563506e040daddccaa948d055e006e971771768df3bb01feeb4386c242b0"},
+ {file = "SQLAlchemy-1.4.42-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4948b6c5f4e56693bbeff52f574279e4ff972ea3353f45967a14c30fb7ae2beb"},
+ {file = "SQLAlchemy-1.4.42-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1811a0b19a08af7750c0b69e38dec3d46e47c4ec1d74b6184d69f12e1c99a5e0"},
+ {file = "SQLAlchemy-1.4.42-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b01d9cd2f9096f688c71a3d0f33f3cd0af8549014e66a7a7dee6fc214a7277d"},
+ {file = "SQLAlchemy-1.4.42-cp37-cp37m-win32.whl", hash = "sha256:bd448b262544b47a2766c34c0364de830f7fb0772d9959c1c42ad61d91ab6565"},
+ {file = "SQLAlchemy-1.4.42-cp37-cp37m-win_amd64.whl", hash = "sha256:04f2598c70ea4a29b12d429a80fad3a5202d56dce19dd4916cc46a965a5ca2e9"},
+ {file = "SQLAlchemy-1.4.42-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:3ab7c158f98de6cb4f1faab2d12973b330c2878d0c6b689a8ca424c02d66e1b3"},
+ {file = "SQLAlchemy-1.4.42-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ee377eb5c878f7cefd633ab23c09e99d97c449dd999df639600f49b74725b80"},
+ {file = "SQLAlchemy-1.4.42-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:934472bb7d8666727746a75670a1f8d91a9cae8c464bba79da30a0f6faccd9e1"},
+ {file = "SQLAlchemy-1.4.42-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fdb94a3d1ba77ff2ef11912192c066f01e68416f554c194d769391638c8ad09a"},
+ {file = "SQLAlchemy-1.4.42-cp38-cp38-win32.whl", hash = "sha256:f0f574465b78f29f533976c06b913e54ab4980b9931b69aa9d306afff13a9471"},
+ {file = "SQLAlchemy-1.4.42-cp38-cp38-win_amd64.whl", hash = "sha256:a85723c00a636eed863adb11f1e8aaa36ad1c10089537823b4540948a8429798"},
+ {file = "SQLAlchemy-1.4.42-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:5ce6929417d5dce5ad1d3f147db81735a4a0573b8fb36e3f95500a06eaddd93e"},
+ {file = "SQLAlchemy-1.4.42-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723e3b9374c1ce1b53564c863d1a6b2f1dc4e97b1c178d9b643b191d8b1be738"},
+ {file = "SQLAlchemy-1.4.42-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:876eb185911c8b95342b50a8c4435e1c625944b698a5b4a978ad2ffe74502908"},
+ {file = "SQLAlchemy-1.4.42-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd49af453e590884d9cdad3586415922a8e9bb669d874ee1dc55d2bc425aacd"},
+ {file = "SQLAlchemy-1.4.42-cp39-cp39-win32.whl", hash = "sha256:e4ef8cb3c5b326f839bfeb6af5f406ba02ad69a78c7aac0fbeeba994ad9bb48a"},
+ {file = "SQLAlchemy-1.4.42-cp39-cp39-win_amd64.whl", hash = "sha256:5f966b64c852592469a7eb759615bbd351571340b8b344f1d3fa2478b5a4c934"},
+ {file = "SQLAlchemy-1.4.42.tar.gz", hash = "sha256:177e41914c476ed1e1b77fd05966ea88c094053e17a85303c4ce007f88eff363"},
+]
+sqlparse = [
+ {file = "sqlparse-0.4.3-py3-none-any.whl", hash = "sha256:0323c0ec29cd52bceabc1b4d9d579e311f3e4961b98d174201d5622a23b85e34"},
+ {file = "sqlparse-0.4.3.tar.gz", hash = "sha256:69ca804846bb114d2ec380e4360a8a340db83f0ccf3afceeb1404df028f57268"},
+]
+submitit = [
+ {file = "submitit-1.4.2-py3-none-any.whl", hash = "sha256:f03711f039583f1f1bb7381dca645e114bf7f145e6fffc23ddfa91301a37925d"},
+ {file = "submitit-1.4.2.tar.gz", hash = "sha256:c82d43a0e1b71e16abf6920eb00ad06d7fb8887187f65d0f7781c68fa45bd908"},
+]
+tabulate = [
+ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+ {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+tensorboard = [
+ {file = "tensorboard-2.9.0-py3-none-any.whl", hash = "sha256:bd78211076dca5efa27260afacfaa96cd05c7db12a6c09cc76a1d6b2987ca621"},
+]
+tensorboard-data-server = [
+ {file = "tensorboard_data_server-0.6.1-py3-none-any.whl", hash = "sha256:809fe9887682d35c1f7d1f54f0f40f98bb1f771b14265b453ca051e2ce58fca7"},
+ {file = "tensorboard_data_server-0.6.1-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fa8cef9be4fcae2f2363c88176638baf2da19c5ec90addb49b1cde05c95c88ee"},
+ {file = "tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl", hash = "sha256:d8237580755e58eff68d1f3abefb5b1e39ae5c8b127cc40920f9c4fb33f4b98a"},
+]
+tensorboard-plugin-wit = [
+ {file = "tensorboard_plugin_wit-1.8.1-py3-none-any.whl", hash = "sha256:ff26bdd583d155aa951ee3b152b3d0cffae8005dc697f72b44a8e8c2a77a8cbe"},
+]
+terminado = [
+ {file = "terminado-0.17.1-py3-none-any.whl", hash = "sha256:8650d44334eba354dd591129ca3124a6ba42c3d5b70df5051b6921d506fdaeae"},
+ {file = "terminado-0.17.1.tar.gz", hash = "sha256:6ccbbcd3a4f8a25a5ec04991f39a0b8db52dfcd487ea0e578d977e6752380333"},
+]
+testfixtures = [
+ {file = "testfixtures-6.18.5-py2.py3-none-any.whl", hash = "sha256:7de200e24f50a4a5d6da7019fb1197aaf5abd475efb2ec2422fdcf2f2eb98c1d"},
+ {file = "testfixtures-6.18.5.tar.gz", hash = "sha256:02dae883f567f5b70fd3ad3c9eefb95912e78ac90be6c7444b5e2f46bf572c84"},
+]
+threadpoolctl = [
+ {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"},
+ {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"},
+]
+timm = [
+ {file = "timm-0.6.7-py3-none-any.whl", hash = "sha256:4bbd7a5c9ae462ec7fec3d99ffc62ac2012010d755248e3de778d50bce5f6186"},
+ {file = "timm-0.6.7.tar.gz", hash = "sha256:340f907906695092cf53fe01a476aa14ad15763545b654bc122ea0daef23071f"},
+]
+tinycss2 = [
+ {file = "tinycss2-1.2.1-py3-none-any.whl", hash = "sha256:2b80a96d41e7c3914b8cda8bc7f705a4d9c49275616e886103dd839dfc847847"},
+ {file = "tinycss2-1.2.1.tar.gz", hash = "sha256:8cff3a8f066c2ec677c06dbc7b45619804a6938478d9d73c284b29d14ecb0627"},
+]
+toml = [
+ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
+ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
+]
+tomli = [
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+]
+torch = [
+ {file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"},
+ {file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"},
+ {file = "torch-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:e9c8f4a311ac29fc7e8e955cfb7733deb5dbe1bdaabf5d4af2765695824b7e0d"},
+ {file = "torch-1.12.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:976c3f997cea38ee91a0dd3c3a42322785414748d1761ef926b789dfa97c6134"},
+ {file = "torch-1.12.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:68104e4715a55c4bb29a85c6a8d57d820e0757da363be1ba680fa8cc5be17b52"},
+ {file = "torch-1.12.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:743784ccea0dc8f2a3fe6a536bec8c4763bd82c1352f314937cb4008d4805de1"},
+ {file = "torch-1.12.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b5dbcca369800ce99ba7ae6dee3466607a66958afca3b740690d88168752abcf"},
+ {file = "torch-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f3b52a634e62821e747e872084ab32fbcb01b7fa7dbb7471b6218279f02a178a"},
+ {file = "torch-1.12.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:8a34a2fbbaa07c921e1b203f59d3d6e00ed379f2b384445773bd14e328a5b6c8"},
+ {file = "torch-1.12.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:42f639501928caabb9d1d55ddd17f07cd694de146686c24489ab8c615c2871f2"},
+ {file = "torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0b44601ec56f7dd44ad8afc00846051162ef9c26a8579dda0a02194327f2d55e"},
+ {file = "torch-1.12.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cd26d8c5640c3a28c526d41ccdca14cf1cbca0d0f2e14e8263a7ac17194ab1d2"},
+ {file = "torch-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:42e115dab26f60c29e298559dbec88444175528b729ae994ec4c65d56fe267dd"},
+ {file = "torch-1.12.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:a8320ba9ad87e80ca5a6a016e46ada4d1ba0c54626e135d99b2129a4541c509d"},
+ {file = "torch-1.12.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:03e31c37711db2cd201e02de5826de875529e45a55631d317aadce2f1ed45aa8"},
+ {file = "torch-1.12.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9b356aea223772cd754edb4d9ecf2a025909b8615a7668ac7d5130f86e7ec421"},
+ {file = "torch-1.12.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6cf6f54b43c0c30335428195589bd00e764a6d27f3b9ba637aaa8c11aaf93073"},
+ {file = "torch-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:f00c721f489089dc6364a01fd84906348fe02243d0af737f944fddb36003400d"},
+ {file = "torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:bfec2843daa654f04fda23ba823af03e7b6f7650a873cdb726752d0e3718dada"},
+ {file = "torch-1.12.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:69fe2cae7c39ccadd65a123793d30e0db881f1c1927945519c5c17323131437e"},
+]
+torchmetrics = [
+ {file = "torchmetrics-0.8.2-py3-none-any.whl", hash = "sha256:19e4ed0305b8c02fd5765a9ff9360c9c622842c2b3491e497ddbf2aec7ce9c5a"},
+ {file = "torchmetrics-0.8.2.tar.gz", hash = "sha256:8cec51df230838b07e1bffe407fd98c25b8e1cdf820525a4ba6ef7f7e5ac4d89"},
+]
+torchtyping = [
+ {file = "torchtyping-0.1.4-py3-none-any.whl", hash = "sha256:485fb6ef3965c39b0de15f00d6f49373e0a3a6993e9733942a63c5e207d35390"},
+ {file = "torchtyping-0.1.4.tar.gz", hash = "sha256:4763375d17752641bd1bff0faaddade29be3c125fca6355e3cee7700e975fdb5"},
+]
+torchvision = [
+ {file = "torchvision-0.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:19286a733c69dcbd417b86793df807bd227db5786ed787c17297741a9b0d0fc7"},
+ {file = "torchvision-0.13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:08f592ea61836ebeceb5c97f4d7a813b9d7dc651bbf7ce4401563ccfae6a21fc"},
+ {file = "torchvision-0.13.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:ef5fe3ec1848123cd0ec74c07658192b3147dcd38e507308c790d5943e87b88c"},
+ {file = "torchvision-0.13.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:099874088df104d54d8008f2a28539ca0117b512daed8bf3c2bbfa2b7ccb187a"},
+ {file = "torchvision-0.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:8e4d02e4d8a203e0c09c10dfb478214c224d080d31efc0dbf36d9c4051f7f3c6"},
+ {file = "torchvision-0.13.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5e631241bee3661de64f83616656224af2e3512eb2580da7c08e08b8c965a8ac"},
+ {file = "torchvision-0.13.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:899eec0b9f3b99b96d6f85b9aa58c002db41c672437677b553015b9135b3be7e"},
+ {file = "torchvision-0.13.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:83e9e2457f23110fd53b0177e1bc621518d6ea2108f570e853b768ce36b7c679"},
+ {file = "torchvision-0.13.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7552e80fa222252b8b217a951c85e172a710ea4cad0ae0c06fbb67addece7871"},
+ {file = "torchvision-0.13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f230a1a40ed70d51e463ce43df243ec520902f8725de2502e485efc5eea9d864"},
+ {file = "torchvision-0.13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e9a563894f9fa40692e24d1aa58c3ef040450017cfed3598ff9637f404f3fe3b"},
+ {file = "torchvision-0.13.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7cb789ceefe6dcd0dc8eeda37bfc45efb7cf34770eac9533861d51ca508eb5b3"},
+ {file = "torchvision-0.13.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:87c137f343197769a51333076e66bfcd576301d2cd8614b06657187c71b06c4f"},
+ {file = "torchvision-0.13.1-cp38-cp38-win_amd64.whl", hash = "sha256:4d8bf321c4380854ef04613935fdd415dce29d1088a7ff99e06e113f0efe9203"},
+ {file = "torchvision-0.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0298bae3b09ac361866088434008d82b99d6458fe8888c8df90720ef4b347d44"},
+ {file = "torchvision-0.13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c5ed609c8bc88c575226400b2232e0309094477c82af38952e0373edef0003fd"},
+ {file = "torchvision-0.13.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3567fb3def829229ec217c1e38f08c5128ff7fb65854cac17ebac358ff7aa309"},
+ {file = "torchvision-0.13.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b167934a5943242da7b1e59318f911d2d253feeca0d13ad5d832b58eed943401"},
+ {file = "torchvision-0.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:0e77706cc90462653620e336bb90daf03d7bf1b88c3a9a3037df8d111823a56e"},
+]
+tornado = [
+ {file = "tornado-6.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:20f638fd8cc85f3cbae3c732326e96addff0a15e22d80f049e00121651e82e72"},
+ {file = "tornado-6.2-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:87dcafae3e884462f90c90ecc200defe5e580a7fbbb4365eda7c7c1eb809ebc9"},
+ {file = "tornado-6.2-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba09ef14ca9893954244fd872798b4ccb2367c165946ce2dd7376aebdde8e3ac"},
+ {file = "tornado-6.2-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8150f721c101abdef99073bf66d3903e292d851bee51910839831caba341a75"},
+ {file = "tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3a2f5999215a3a06a4fc218026cd84c61b8b2b40ac5296a6db1f1451ef04c1e"},
+ {file = "tornado-6.2-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:5f8c52d219d4995388119af7ccaa0bcec289535747620116a58d830e7c25d8a8"},
+ {file = "tornado-6.2-cp37-abi3-musllinux_1_1_i686.whl", hash = "sha256:6fdfabffd8dfcb6cf887428849d30cf19a3ea34c2c248461e1f7d718ad30b66b"},
+ {file = "tornado-6.2-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:1d54d13ab8414ed44de07efecb97d4ef7c39f7438cf5e976ccd356bebb1b5fca"},
+ {file = "tornado-6.2-cp37-abi3-win32.whl", hash = "sha256:5c87076709343557ef8032934ce5f637dbb552efa7b21d08e89ae7619ed0eb23"},
+ {file = "tornado-6.2-cp37-abi3-win_amd64.whl", hash = "sha256:e5f923aa6a47e133d1cf87d60700889d7eae68988704e20c75fb2d65677a8e4b"},
+ {file = "tornado-6.2.tar.gz", hash = "sha256:9b630419bde84ec666bfd7ea0a4cb2a8a651c2d5cccdbdd1972a0c859dfc3c13"},
+]
+tqdm = [
+ {file = "tqdm-4.64.0-py2.py3-none-any.whl", hash = "sha256:74a2cdefe14d11442cedf3ba4e21a3b84ff9a2dbdc6cfae2c34addb2a14a5ea6"},
+ {file = "tqdm-4.64.0.tar.gz", hash = "sha256:40be55d30e200777a307a7585aee69e4eabb46b4ec6a4b4a5f2d9f11e7d5408d"},
+]
+traitlets = [
+ {file = "traitlets-5.9.0-py3-none-any.whl", hash = "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8"},
+ {file = "traitlets-5.9.0.tar.gz", hash = "sha256:f6cde21a9c68cf756af02035f72d5a723bf607e862e7be33ece505abf4a3bad9"},
+]
+typed-ast = [
+ {file = "typed_ast-1.5.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:669dd0c4167f6f2cd9f57041e03c3c2ebf9063d0757dc89f79ba1daa2bfca9d4"},
+ {file = "typed_ast-1.5.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:211260621ab1cd7324e0798d6be953d00b74e0428382991adfddb352252f1d62"},
+ {file = "typed_ast-1.5.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:267e3f78697a6c00c689c03db4876dd1efdfea2f251a5ad6555e82a26847b4ac"},
+ {file = "typed_ast-1.5.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c542eeda69212fa10a7ada75e668876fdec5f856cd3d06829e6aa64ad17c8dfe"},
+ {file = "typed_ast-1.5.4-cp310-cp310-win_amd64.whl", hash = "sha256:a9916d2bb8865f973824fb47436fa45e1ebf2efd920f2b9f99342cb7fab93f72"},
+ {file = "typed_ast-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:79b1e0869db7c830ba6a981d58711c88b6677506e648496b1f64ac7d15633aec"},
+ {file = "typed_ast-1.5.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a94d55d142c9265f4ea46fab70977a1944ecae359ae867397757d836ea5a3f47"},
+ {file = "typed_ast-1.5.4-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:183afdf0ec5b1b211724dfef3d2cad2d767cbefac291f24d69b00546c1837fb6"},
+ {file = "typed_ast-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:639c5f0b21776605dd6c9dbe592d5228f021404dafd377e2b7ac046b0349b1a1"},
+ {file = "typed_ast-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cf4afcfac006ece570e32d6fa90ab74a17245b83dfd6655a6f68568098345ff6"},
+ {file = "typed_ast-1.5.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed855bbe3eb3715fca349c80174cfcfd699c2f9de574d40527b8429acae23a66"},
+ {file = "typed_ast-1.5.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6778e1b2f81dfc7bc58e4b259363b83d2e509a65198e85d5700dfae4c6c8ff1c"},
+ {file = "typed_ast-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:0261195c2062caf107831e92a76764c81227dae162c4f75192c0d489faf751a2"},
+ {file = "typed_ast-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2efae9db7a8c05ad5547d522e7dbe62c83d838d3906a3716d1478b6c1d61388d"},
+ {file = "typed_ast-1.5.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7d5d014b7daa8b0bf2eaef684295acae12b036d79f54178b92a2b6a56f92278f"},
+ {file = "typed_ast-1.5.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:370788a63915e82fd6f212865a596a0fefcbb7d408bbbb13dea723d971ed8bdc"},
+ {file = "typed_ast-1.5.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4e964b4ff86550a7a7d56345c7864b18f403f5bd7380edf44a3c1fb4ee7ac6c6"},
+ {file = "typed_ast-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:683407d92dc953c8a7347119596f0b0e6c55eb98ebebd9b23437501b28dcbb8e"},
+ {file = "typed_ast-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4879da6c9b73443f97e731b617184a596ac1235fe91f98d279a7af36c796da35"},
+ {file = "typed_ast-1.5.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3e123d878ba170397916557d31c8f589951e353cc95fb7f24f6bb69adc1a8a97"},
+ {file = "typed_ast-1.5.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebd9d7f80ccf7a82ac5f88c521115cc55d84e35bf8b446fcd7836eb6b98929a3"},
+ {file = "typed_ast-1.5.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:98f80dee3c03455e92796b58b98ff6ca0b2a6f652120c263efdba4d6c5e58f72"},
+ {file = "typed_ast-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:0fdbcf2fef0ca421a3f5912555804296f0b0960f0418c440f5d6d3abb549f3e1"},
+ {file = "typed_ast-1.5.4.tar.gz", hash = "sha256:39e21ceb7388e4bb37f4c679d72707ed46c2fbf2a5609b8b8ebc4b067d977df2"},
+]
+typeguard = [
+ {file = "typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1"},
+ {file = "typeguard-2.13.3.tar.gz", hash = "sha256:00edaa8da3a133674796cf5ea87d9f4b4c367d77476e185e80251cc13dfbb8c4"},
+]
+typing-extensions = [
+ {file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"},
+ {file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"},
+]
+urllib3 = [
+ {file = "urllib3-1.26.9-py2.py3-none-any.whl", hash = "sha256:44ece4d53fb1706f667c9bd1c648f5469a2ec925fcf3a776667042d645472c14"},
+ {file = "urllib3-1.26.9.tar.gz", hash = "sha256:aabaf16477806a5e1dd19aa41f8c2b7950dd3c746362d7e3223dbe6de6ac448e"},
+]
+virtualenv = [
+ {file = "virtualenv-20.15.1-py2.py3-none-any.whl", hash = "sha256:b30aefac647e86af6d82bfc944c556f8f1a9c90427b2fb4e3bfbf338cb82becf"},
+ {file = "virtualenv-20.15.1.tar.gz", hash = "sha256:288171134a2ff3bfb1a2f54f119e77cd1b81c29fc1265a2356f3e8d14c7d58c4"},
+]
+waitress = [
+ {file = "waitress-2.1.2-py3-none-any.whl", hash = "sha256:7500c9625927c8ec60f54377d590f67b30c8e70ef4b8894214ac6e4cad233d2a"},
+ {file = "waitress-2.1.2.tar.gz", hash = "sha256:780a4082c5fbc0fde6a2fcfe5e26e6efc1e8f425730863c04085769781f51eba"},
+]
+wcwidth = [
+ {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"},
+ {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"},
+]
+webdataset = [
+ {file = "webdataset-0.1.103-py3-none-any.whl", hash = "sha256:a2cf638a767bcf23b0a83fec321aa18f576f2ca0a40f7803a7b839706ad1076e"},
+ {file = "webdataset-0.1.103.tar.gz", hash = "sha256:ea41e983924bfc2678f2bea71208fb82b2116dbc1dcaf212f892e2748f925b5f"},
+]
+webencodings = [
+ {file = "webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78"},
+ {file = "webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923"},
+]
+websocket-client = [
+ {file = "websocket-client-1.4.1.tar.gz", hash = "sha256:f9611eb65c8241a67fb373bef040b3cf8ad377a9f6546a12b620b6511e8ea9ef"},
+ {file = "websocket_client-1.4.1-py3-none-any.whl", hash = "sha256:398909eb7e261f44b8f4bd474785b6ec5f5b499d4953342fe9755e01ef624090"},
+]
+werkzeug = [
+ {file = "Werkzeug-2.1.2-py3-none-any.whl", hash = "sha256:72a4b735692dd3135217911cbeaa1be5fa3f62bffb8745c5215420a03dc55255"},
+ {file = "Werkzeug-2.1.2.tar.gz", hash = "sha256:1ce08e8093ed67d638d63879fd1ba3735817f7a80de3674d293f5984f25fb6e6"},
+]
+wheel = [
+ {file = "wheel-0.37.1-py2.py3-none-any.whl", hash = "sha256:4bdcd7d840138086126cd09254dc6195fb4fc6f01c050a1d7236f2630db1d22a"},
+ {file = "wheel-0.37.1.tar.gz", hash = "sha256:e9a504e793efbca1b8e0e9cb979a249cf4a0a7b5b8c9e8b65a5e39d49529c1c4"},
+]
+widgetsnbextension = [
+ {file = "widgetsnbextension-4.0.7-py3-none-any.whl", hash = "sha256:be3228a73bbab189a16be2d4a3cd89ecbd4e31948bfdc64edac17dcdee3cd99c"},
+ {file = "widgetsnbextension-4.0.7.tar.gz", hash = "sha256:ea67c17a7cd4ae358f8f46c3b304c40698bc0423732e3f273321ee141232c8be"},
+]
+xmltodict = [
+ {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
+ {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+]
+yarl = [
+ {file = "yarl-1.7.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f2a8508f7350512434e41065684076f640ecce176d262a7d54f0da41d99c5a95"},
+ {file = "yarl-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:da6df107b9ccfe52d3a48165e48d72db0eca3e3029b5b8cb4fe6ee3cb870ba8b"},
+ {file = "yarl-1.7.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a1d0894f238763717bdcfea74558c94e3bc34aeacd3351d769460c1a586a8b05"},
+ {file = "yarl-1.7.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfe4b95b7e00c6635a72e2d00b478e8a28bfb122dc76349a06e20792eb53a523"},
+ {file = "yarl-1.7.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c145ab54702334c42237a6c6c4cc08703b6aa9b94e2f227ceb3d477d20c36c63"},
+ {file = "yarl-1.7.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ca56f002eaf7998b5fcf73b2421790da9d2586331805f38acd9997743114e98"},
+ {file = "yarl-1.7.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1d3d5ad8ea96bd6d643d80c7b8d5977b4e2fb1bab6c9da7322616fd26203d125"},
+ {file = "yarl-1.7.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:167ab7f64e409e9bdd99333fe8c67b5574a1f0495dcfd905bc7454e766729b9e"},
+ {file = "yarl-1.7.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:95a1873b6c0dd1c437fb3bb4a4aaa699a48c218ac7ca1e74b0bee0ab16c7d60d"},
+ {file = "yarl-1.7.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6152224d0a1eb254f97df3997d79dadd8bb2c1a02ef283dbb34b97d4f8492d23"},
+ {file = "yarl-1.7.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:5bb7d54b8f61ba6eee541fba4b83d22b8a046b4ef4d8eb7f15a7e35db2e1e245"},
+ {file = "yarl-1.7.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:9c1f083e7e71b2dd01f7cd7434a5f88c15213194df38bc29b388ccdf1492b739"},
+ {file = "yarl-1.7.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f44477ae29025d8ea87ec308539f95963ffdc31a82f42ca9deecf2d505242e72"},
+ {file = "yarl-1.7.2-cp310-cp310-win32.whl", hash = "sha256:cff3ba513db55cc6a35076f32c4cdc27032bd075c9faef31fec749e64b45d26c"},
+ {file = "yarl-1.7.2-cp310-cp310-win_amd64.whl", hash = "sha256:c9c6d927e098c2d360695f2e9d38870b2e92e0919be07dbe339aefa32a090265"},
+ {file = "yarl-1.7.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9b4c77d92d56a4c5027572752aa35082e40c561eec776048330d2907aead891d"},
+ {file = "yarl-1.7.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c01a89a44bb672c38f42b49cdb0ad667b116d731b3f4c896f72302ff77d71656"},
+ {file = "yarl-1.7.2-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c19324a1c5399b602f3b6e7db9478e5b1adf5cf58901996fc973fe4fccd73eed"},
+ {file = "yarl-1.7.2-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3abddf0b8e41445426d29f955b24aeecc83fa1072be1be4e0d194134a7d9baee"},
+ {file = "yarl-1.7.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6a1a9fe17621af43e9b9fcea8bd088ba682c8192d744b386ee3c47b56eaabb2c"},
+ {file = "yarl-1.7.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8b0915ee85150963a9504c10de4e4729ae700af11df0dc5550e6587ed7891e92"},
+ {file = "yarl-1.7.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:29e0656d5497733dcddc21797da5a2ab990c0cb9719f1f969e58a4abac66234d"},
+ {file = "yarl-1.7.2-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:bf19725fec28452474d9887a128e98dd67eee7b7d52e932e6949c532d820dc3b"},
+ {file = "yarl-1.7.2-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:d6f3d62e16c10e88d2168ba2d065aa374e3c538998ed04996cd373ff2036d64c"},
+ {file = "yarl-1.7.2-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:ac10bbac36cd89eac19f4e51c032ba6b412b3892b685076f4acd2de18ca990aa"},
+ {file = "yarl-1.7.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:aa32aaa97d8b2ed4e54dc65d241a0da1c627454950f7d7b1f95b13985afd6c5d"},
+ {file = "yarl-1.7.2-cp36-cp36m-win32.whl", hash = "sha256:87f6e082bce21464857ba58b569370e7b547d239ca22248be68ea5d6b51464a1"},
+ {file = "yarl-1.7.2-cp36-cp36m-win_amd64.whl", hash = "sha256:ac35ccde589ab6a1870a484ed136d49a26bcd06b6a1c6397b1967ca13ceb3913"},
+ {file = "yarl-1.7.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a467a431a0817a292121c13cbe637348b546e6ef47ca14a790aa2fa8cc93df63"},
+ {file = "yarl-1.7.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ab0c3274d0a846840bf6c27d2c60ba771a12e4d7586bf550eefc2df0b56b3b4"},
+ {file = "yarl-1.7.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d260d4dc495c05d6600264a197d9d6f7fc9347f21d2594926202fd08cf89a8ba"},
+ {file = "yarl-1.7.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fc4dd8b01a8112809e6b636b00f487846956402834a7fd59d46d4f4267181c41"},
+ {file = "yarl-1.7.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c1164a2eac148d85bbdd23e07dfcc930f2e633220f3eb3c3e2a25f6148c2819e"},
+ {file = "yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:67e94028817defe5e705079b10a8438b8cb56e7115fa01640e9c0bb3edf67332"},
+ {file = "yarl-1.7.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:89ccbf58e6a0ab89d487c92a490cb5660d06c3a47ca08872859672f9c511fc52"},
+ {file = "yarl-1.7.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8cce6f9fa3df25f55521fbb5c7e4a736683148bcc0c75b21863789e5185f9185"},
+ {file = "yarl-1.7.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:211fcd65c58bf250fb994b53bc45a442ddc9f441f6fec53e65de8cba48ded986"},
+ {file = "yarl-1.7.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c10ea1e80a697cf7d80d1ed414b5cb8f1eec07d618f54637067ae3c0334133c4"},
+ {file = "yarl-1.7.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:52690eb521d690ab041c3919666bea13ab9fbff80d615ec16fa81a297131276b"},
+ {file = "yarl-1.7.2-cp37-cp37m-win32.whl", hash = "sha256:695ba021a9e04418507fa930d5f0704edbce47076bdcfeeaba1c83683e5649d1"},
+ {file = "yarl-1.7.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c17965ff3706beedafd458c452bf15bac693ecd146a60a06a214614dc097a271"},
+ {file = "yarl-1.7.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fce78593346c014d0d986b7ebc80d782b7f5e19843ca798ed62f8e3ba8728576"},
+ {file = "yarl-1.7.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c2a1ac41a6aa980db03d098a5531f13985edcb451bcd9d00670b03129922cd0d"},
+ {file = "yarl-1.7.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:39d5493c5ecd75c8093fa7700a2fb5c94fe28c839c8e40144b7ab7ccba6938c8"},
+ {file = "yarl-1.7.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eb6480ef366d75b54c68164094a6a560c247370a68c02dddb11f20c4c6d3c9d"},
+ {file = "yarl-1.7.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ba63585a89c9885f18331a55d25fe81dc2d82b71311ff8bd378fc8004202ff6"},
+ {file = "yarl-1.7.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e39378894ee6ae9f555ae2de332d513a5763276a9265f8e7cbaeb1b1ee74623a"},
+ {file = "yarl-1.7.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c0910c6b6c31359d2f6184828888c983d54d09d581a4a23547a35f1d0b9484b1"},
+ {file = "yarl-1.7.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6feca8b6bfb9eef6ee057628e71e1734caf520a907b6ec0d62839e8293e945c0"},
+ {file = "yarl-1.7.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8300401dc88cad23f5b4e4c1226f44a5aa696436a4026e456fe0e5d2f7f486e6"},
+ {file = "yarl-1.7.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:788713c2896f426a4e166b11f4ec538b5736294ebf7d5f654ae445fd44270832"},
+ {file = "yarl-1.7.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fd547ec596d90c8676e369dd8a581a21227fe9b4ad37d0dc7feb4ccf544c2d59"},
+ {file = "yarl-1.7.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:737e401cd0c493f7e3dd4db72aca11cfe069531c9761b8ea474926936b3c57c8"},
+ {file = "yarl-1.7.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:baf81561f2972fb895e7844882898bda1eef4b07b5b385bcd308d2098f1a767b"},
+ {file = "yarl-1.7.2-cp38-cp38-win32.whl", hash = "sha256:ede3b46cdb719c794427dcce9d8beb4abe8b9aa1e97526cc20de9bd6583ad1ef"},
+ {file = "yarl-1.7.2-cp38-cp38-win_amd64.whl", hash = "sha256:cc8b7a7254c0fc3187d43d6cb54b5032d2365efd1df0cd1749c0c4df5f0ad45f"},
+ {file = "yarl-1.7.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:580c1f15500e137a8c37053e4cbf6058944d4c114701fa59944607505c2fe3a0"},
+ {file = "yarl-1.7.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ec1d9a0d7780416e657f1e405ba35ec1ba453a4f1511eb8b9fbab81cb8b3ce1"},
+ {file = "yarl-1.7.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3bf8cfe8856708ede6a73907bf0501f2dc4e104085e070a41f5d88e7faf237f3"},
+ {file = "yarl-1.7.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1be4bbb3d27a4e9aa5f3df2ab61e3701ce8fcbd3e9846dbce7c033a7e8136746"},
+ {file = "yarl-1.7.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:534b047277a9a19d858cde163aba93f3e1677d5acd92f7d10ace419d478540de"},
+ {file = "yarl-1.7.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6ddcd80d79c96eb19c354d9dca95291589c5954099836b7c8d29278a7ec0bda"},
+ {file = "yarl-1.7.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9bfcd43c65fbb339dc7086b5315750efa42a34eefad0256ba114cd8ad3896f4b"},
+ {file = "yarl-1.7.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f64394bd7ceef1237cc604b5a89bf748c95982a84bcd3c4bbeb40f685c810794"},
+ {file = "yarl-1.7.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:044daf3012e43d4b3538562da94a88fb12a6490652dbc29fb19adfa02cf72eac"},
+ {file = "yarl-1.7.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:368bcf400247318382cc150aaa632582d0780b28ee6053cd80268c7e72796dec"},
+ {file = "yarl-1.7.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:bab827163113177aee910adb1f48ff7af31ee0289f434f7e22d10baf624a6dfe"},
+ {file = "yarl-1.7.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:0cba38120db72123db7c58322fa69e3c0efa933040ffb586c3a87c063ec7cae8"},
+ {file = "yarl-1.7.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:59218fef177296451b23214c91ea3aba7858b4ae3306dde120224cfe0f7a6ee8"},
+ {file = "yarl-1.7.2-cp39-cp39-win32.whl", hash = "sha256:1edc172dcca3f11b38a9d5c7505c83c1913c0addc99cd28e993efeaafdfaa18d"},
+ {file = "yarl-1.7.2-cp39-cp39-win_amd64.whl", hash = "sha256:797c2c412b04403d2da075fb93c123df35239cd7b4cc4e0cd9e5839b73f52c58"},
+ {file = "yarl-1.7.2.tar.gz", hash = "sha256:45399b46d60c253327a460e99856752009fcee5f5d3c80b2f7c0cae1c38d56dd"},
+]
+zipp = [
+ {file = "zipp-3.8.0-py3-none-any.whl", hash = "sha256:c4f6e5bbf48e74f7a38e7cc5b0480ff42b0ae5178957d564d18932525d5cf099"},
+ {file = "zipp-3.8.0.tar.gz", hash = "sha256:56bf8aadb83c24db6c4b577e13de374ccfb67da2078beba1d037c17980bf43ad"},
+]
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..bda9242
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,74 @@
+[tool.poetry]
+name = "ocl"
+version = "0.1.0"
+description = ""
+authors = ["Max Horn "]
+
+[tool.poetry.scripts]
+ocl_train = "ocl.cli.train:train"
+ocl_eval = "ocl.cli.eval:evaluate"
+ocl_compute_dataset_size = "ocl.cli.compute_dataset_size:compute_size"
+
+[tool.poetry.dependencies]
+python = ">=3.7.1,<3.9"
+webdataset = "^0.1.103"
+# There seems to be an issue in torch 1.12.x with masking and multi-head
+# attention: it prevents the use of masks without a batch dimension (see the
+# illustrative sketch after this file). The torch version is pinned for now.
+torch = "1.12.*"
+pytorch-lightning = "^1.5.10"
+hydra-zen = "^0.7.0"
+torchtyping = "^0.1.4"
+hydra-core = "^1.2.0"
+pluggy = "^1.0.0"
+importlib-metadata = "4.2"
+torchvision = "0.13.*"
+Pillow = "9.0.1" # Newer versions of pillow seem to result in segmentation faults.
+torchmetrics = "^0.8.1"
+matplotlib = "^3.5.1"
+moviepy = "^1.0.3"
+scipy = "<=1.8"
+awscli = "^1.22.90"
+scikit-learn = "^1.0.2"
+pyamg = "^4.2.3"
+botocore = { extras = ["crt"], version = "^1.27.22" }
+timm = {version = "0.6.7", optional = true}
+hydra-submitit-launcher = { version = "^1.2.0", optional = true }
+decord = "0.6.0"
+motmetrics = "^1.2.5"
+
+ftfy = {version = "^6.1.1", optional = true}
+regex = {version = "^2022.7.9", optional = true}
+mlflow = {version = "^1.29.0", optional = true}
+einops = "^0.6.0"
+jupyter = "^1.0.0"
+
+[tool.poetry.dev-dependencies]
+black = "^22.1.0"
+pytest = "^7.0.1"
+flake8 = "^4.0.1"
+flake8-isort = "^4.1.1"
+pre-commit = "^2.17.0"
+flake8-tidy-imports = "^4.6.0"
+flake8-bugbear = "^22.1.11"
+flake8-docstrings = "^1.6.0"
+
+[tool.poetry.extras]
+timm = ["timm"]
+clip = ["clip", "ftfy", "regex"]
+submitit = ["hydra-submitit-launcher"]
+mlflow = ["mlflow"]
+
+[build-system]
+requires = ["poetry-core<=1.0.4"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.black]
+line-length = 101
+target-version = ["py38"]
+
+[tool.isort]
+profile = "black"
+line_length = 101
+skip_gitignore = true
+remove_redundant_aliases = true
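
The dependency comment in `pyproject.toml` above refers to multi-head attention being called with inputs and masks that carry no batch dimension. The snippet below is a minimal, illustrative sketch of that call pattern under assumed shapes; it is not part of the repository, and whether it runs cleanly or raises an error depends on the installed torch version, which is what the comment warns about.

```python
# Illustrative only: the unbatched multi-head attention call pattern mentioned
# in the pyproject.toml comment. Shapes are assumptions for demonstration.
import torch
import torch.nn as nn

embed_dim, num_heads, seq_len = 16, 4, 5
mha = nn.MultiheadAttention(embed_dim, num_heads)

# Unbatched inputs: (seq_len, embed_dim) instead of (seq_len, batch, embed_dim).
query = torch.randn(seq_len, embed_dim)
key = torch.randn(seq_len, embed_dim)
value = torch.randn(seq_len, embed_dim)

# Unbatched key padding mask: (seq_len,) instead of (batch, seq_len);
# True marks positions that attention should ignore.
key_padding_mask = torch.zeros(seq_len, dtype=torch.bool)
key_padding_mask[-1] = True

# On torch versions affected by the masking issue this call may fail; on
# others it returns an output of shape (seq_len, embed_dim).
out, attn_weights = mha(query, key, value, key_padding_mask=key_padding_mask)
print(out.shape)
```
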
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..0ad3a9b
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,25 @@
+[flake8]
+select=
+ # F: errors from pyflake
+ F,
+ # W, E: warnings/errors from pycodestyle (PEP8)
+ W, E,
+ # I: problems with imports
+ I,
+ # B: bugbear warnings ("likely bugs and design problems")
+ B,
+ # D: docstring warnings from pydocstyle
+ D
+ignore=
+ # E203: whitespace before ':' (incompatible with black)
+ E203,
+ # E731: do not use a lambda expression, use a def (local def is often ugly)
+ E731,
+ # W503: line break before binary operator (incompatible with black)
+ W503,
+ # D1: docstring warnings related to missing documentation
+ D1
+max-line-length = 101
+ban-relative-imports = true
+docstring-convention = google
+exclude = .*,__pycache__,./outputs