DRESS: Disentangled Representation-based Self-Supervised Meta-Learning for Diverse Tasks [arXiv]
Authors: Wei Cui, Tongzi Wu, Jesse C. Cresswell, Yi Sui, Keyvan Golestan
This repository contains the official implementation of the paper DRESS: Disentangled Representation-based Self-Supervised Meta-Learning for Diverse Tasks. It includes both training and evaluation code.
The code files within the repository are organized as follows:
- main.py: the main entry point of the program.
- partition_generators.py: generates supervised and self-supervised partitions on each dataset.
- task_generator.py: generates few-shot learning tasks from any given partition.
- utils.py: helper functions.
The sub-folders within the repository are as follows:
- scripts/: scripts to train, evaluate, and produce visualizations.
- encoders/: encoder classes for obtaining the latent spaces.
- dataset_loaders/: scripts for loading each dataset used in the experiments.
- baselines/: implementations of baseline methods.
- analyze_results/: scripts for post-processing results.
- visualization_results/: visualizations of tasks constructed via DRESS.
Create a folder named data/ under the main directory to house the raw data.
The datasets used in the experiments are loaded by their respective loader scripts under dataset_loaders/. The source data is prepared as follows:
- smallNORB: automatically downloaded within our script via the tensorflow_datasets package.
- shapes3D: download 3dshapes.h5 from Google Cloud Storage and place it under data/shapes3d/.
- causal3D: download trainset.tar.gz and testset.tar.gz from the dataset homepage and extract them under data/causal3d/train/ and data/causal3d/test/ respectively.
- MPI3D: download mpi3d_toy.npz from this link and place it under data/mpi3d/.
- CelebA: automatically downloaded within our script via the torchvision package.
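Before launching training, it can help to confirm that the manually downloaded files sit where the loaders expect them. The snippet below is a minimal sketch and is not part of the repository: the helper name find_missing and the EXPECTED table are hypothetical, and the paths are copied from the list above. smallNORB and CelebA are skipped because they are downloaded automatically.

```python
from pathlib import Path

# Expected locations, copied from the dataset list in this README.
# This table is an illustration, not something the repository defines.
EXPECTED = {
    "shapes3D": ["shapes3d/3dshapes.h5"],
    "causal3D": ["causal3d/train", "causal3d/test"],
    "MPI3D": ["mpi3d/mpi3d_toy.npz"],
}


def find_missing(data_root="data"):
    """Return {dataset: [missing paths]} for the manually prepared datasets."""
    root = Path(data_root)
    missing = {}
    for name, rel_paths in EXPECTED.items():
        absent = [str(root / p) for p in rel_paths if not (root / p).exists()]
        if absent:
            missing[name] = absent
    return missing


if __name__ == "__main__":
    # Create data/ if it does not exist yet, then report what is still missing.
    Path("data").mkdir(exist_ok=True)
    for name, paths in find_missing().items():
        print(f"{name}: missing {', '.join(paths)}")
```

Run from the main directory, the script prints one line per dataset that still has files or folders to prepare, and prints nothing once everything is in place.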
Create a conda environment from the environment.yml file in this repository, for example with conda env create -f environment.yml.