googleinterns/fairness_ssl

Towards Group Robustness in the Presence of Partial Group Labels

  • This is not an officially supported Google product.

Install dependencies

sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
virtualenv -p python3 --system-site-packages env
. env/bin/activate
pip install -r requirements.txt
  • The code has been tested on Ubuntu 18.04 with CUDA 9.1.

Datasets

Download or generate the datasets as follows:

  • Waterbirds: Download a tarball of the dataset. Place the contents under the data/dataset/celeba_dataset directory.
  • CMNIST: Download the MNIST dataset from this website. Place the contents in the data/dataset/mnist_dataset/ directory.
  • Adult: Download the dataset from the UCI repository. Place the contents in the data/dataset/raw/ directory.
  • CelebA: Download the dataset from Kaggle. Place it under the data/datasets/ directory.
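
Before launching a run, it can help to confirm that each dataset landed where the list above says it should. The following is a minimal sketch, not part of the repository: the helper name `missing_dataset_dirs` and the `EXPECTED_DIRS` mapping are hypothetical, and the paths are copied from the list as written.

```python
from pathlib import Path

# Hypothetical mapping for illustration; paths are copied verbatim from the
# dataset list in this README and are not read from the repository code.
EXPECTED_DIRS = {
    "Waterbirds": "data/dataset/celeba_dataset",
    "CMNIST": "data/dataset/mnist_dataset",
    "Adult": "data/dataset/raw",
    "CelebA": "data/datasets",
}


def missing_dataset_dirs(repo_root="."):
    """Return the names of datasets whose directory is absent or empty."""
    root = Path(repo_root)
    missing = []
    for name, rel_path in EXPECTED_DIRS.items():
        directory = root / rel_path
        # A directory that exists but holds no entries counts as missing.
        if not directory.is_dir() or not any(directory.iterdir()):
            missing.append(name)
    return missing
```

Calling `missing_dataset_dirs()` from the repository root returns an empty list once every directory contains at least one file or subdirectory. It only checks for presence and does not validate the dataset contents.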

Running experiments

Several run scripts are provided in the bin/ directory. Each bin/run_* file launches a single hyper-parameter run. The run files provided with the code contain the best hyper-parameters obtained after cross-validation.

An example command with the relevant flags is provided below. Details on each flag are available in the file train_and_eval_loop.py.

    python train_and_eval_loop.py \
        --dataset 'Waterbirds' \
        --model_type 'resnet50' \
        --method 'worstoffdro' \
        --optimizer 'SGD' \
        --learning_rate 1e-5 \
        --noflag_saveckpt \
        --batch_size 128 \
        --num_epoch 300 \
        --weight_decay 1.0 \
        --lab_split 0.1 \
        --worstoffdro_stepsize 0.001 \
        --worstoffdro_marginals=.53,.25,.07,.15 \
        --epsilon=0.001 \
        --ckpt_prefix "results" \
        --flag_run_all
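
The `--worstoffdro_marginals` value in the command above is a comma-separated list of four numbers that sum to 1. Assuming these are per-group marginal probabilities (the README does not say so explicitly), a value can be sanity-checked before starting a long run. The sketch below is illustrative only: `parse_marginals` and its `expected_groups` parameter are hypothetical and do not reflect how train_and_eval_loop.py parses the flag.

```python
import math


def parse_marginals(text, expected_groups=4):
    """Parse a comma-separated string such as '.53,.25,.07,.15'.

    Assumption: the values are probabilities over `expected_groups` groups,
    so they must be non-negative and sum to 1.
    """
    values = [float(part) for part in text.split(",")]
    if len(values) != expected_groups:
        raise ValueError(
            f"expected {expected_groups} values, got {len(values)}"
        )
    if any(value < 0 for value in values):
        raise ValueError("marginals must be non-negative")
    # Allow for small floating-point rounding in the sum.
    if not math.isclose(sum(values), 1.0, abs_tol=1e-6):
        raise ValueError(f"marginals sum to {sum(values)}, expected 1.0")
    return values
```

For example, `parse_marginals(".53,.25,.07,.15")` accepts the value from the example command, while a list with a missing entry or a sum far from 1 raises `ValueError`.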