Intriguing Properties of Data Attribution on Diffusion Models (ICLR 2024)


[Project Page] | [arXiv] | [Data Repository]


TL;DR:

We report counter-intuitive observations that theoretically unjustified design choices 
for attributing diffusion models empirically outperform previous baselines 
by a large margin.

Visualization of proponents and opponents on ArtBench-2 using TRAK and D-TRAK with different numbers of timesteps (10 or 100). For each sample of interest, the 5 most positively influential and the 3 most negatively influential training samples are shown, with the influence score below each sample.


Counterfactual visualization on CIFAR-2
Counterfactual visualization on ArtBench-2

How to run

Quickstart

Check quickstart.ipynb to run data attribution on pre-trained diffusion models loaded directly from Hugging Face!

Replicating the paper's results

Setup

To get started, follow these steps:

  1. Clone the GitHub Repository: Begin by cloning the repository using the command:
    git clone https://github.com/sail-sg/D-TRAK.git
  2. Set Up Python Environment: Ensure you have Python 3.8. For example, create and activate a conda environment:
    conda create -n dtrak python=3.8 -y
    conda activate dtrak
  3. Install Dependencies: Install the necessary dependencies by running:
    pip install -r requirements.txt
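
Before installing the dependencies, it can help to confirm that the active interpreter matches the version the setup expects. The snippet below is a minimal sketch; the helper name `is_expected_python` is hypothetical and not part of the repository.

```python
import sys


def is_expected_python(version_info=None, expected=(3, 8)):
    """Return True when the interpreter's (major, minor) matches `expected`."""
    # Hypothetical helper for illustration; not part of the D-TRAK repository.
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) == tuple(expected)


if __name__ == "__main__":
    status = "OK" if is_expected_python() else "mismatch"
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}: {status}")
```

Run it inside the activated `dtrak` environment; a mismatch suggests the wrong environment is active.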

Commands for LDS evaluation

We provide the commands to run experiments on CIFAR-2. The same workflow transfers to other datasets.

  1. Data pre-processing:

    cd CIFAR2

    Run 00_EDA.ipynb to create dataset splits and subsets of the training set.

  2. Train a diffusion model and generate images:

    bash scripts/run_train.sh 0 18888 5000-0.5
    bash scripts/run_gen.sh 0 0 5000-0.5
  3. Construct the LDS benchmark:

    Train 64 models corresponding to 64 subsets of the training set

    bash scripts/run_lds_val_sub.sh 0 18888 5000-0.5 0 63

    Evaluate the model outputs on the validation set

    bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_val.pkl 0 63
    bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_val.pkl 0 63
    bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_val.pkl 0 63

    Evaluate the model outputs on the generation set

    bash scripts/run_eval_lds_val_sub.sh 0 0 5000-0.5 idx_gen.pkl 0 63
    bash scripts/run_eval_lds_val_sub.sh 0 1 5000-0.5 idx_gen.pkl 0 63
    bash scripts/run_eval_lds_val_sub.sh 0 2 5000-0.5 idx_gen.pkl 0 63
  4. Compute gradients:

    We shard the training set into 5 parts of 1000 examples each.

    Use the following commands to compute the gradients to be used for TRAK.

    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 loss uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 loss uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 loss uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 loss uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 loss uniform 10 32768

    Use the following commands to compute the gradients to be used for D-TRAK.

    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 1 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 2 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 3 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-train.pkl 4 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-val.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
    bash scripts/run_grad.sh 0 0 5000-0.5 idx-gen.pkl 0 ddpm/checkpoint-8000 mean-squared-l2-norm uniform 10 32768
  5. Compute the TRAK/D-TRAK attributions and evaluate the LDS scores

    Run notebooks in methods/04_if.

    The implementations of other baselines can also be found in methods.
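
To make the final step concrete, here is a minimal sketch of how an LDS score can be computed for one sample of interest. It assumes the usual definition of the linear datamodeling score: the Spearman rank correlation, across the retrained subset models, between the output predicted from attribution scores (the sum of scores over the subset's training examples) and the output actually measured. All names (`ranks`, `spearman`, `lds_for_sample`) and the toy numbers are hypothetical; the notebooks in methods/04_if may differ in details.

```python
from statistics import mean


def ranks(values):
    """Rank values from 1..n, giving tied values their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over a run of tied values.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = avg_rank
        i = j + 1
    return result


def spearman(x, y):
    """Spearman rank correlation: Pearson correlation of the ranks."""
    rx, ry = ranks(x), ranks(y)
    mx, my = mean(rx), mean(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var_x = sum((a - mx) ** 2 for a in rx)
    var_y = sum((b - my) ** 2 for b in ry)
    # Undefined (division by zero) if either input is constant.
    return cov / (var_x * var_y) ** 0.5


def lds_for_sample(scores, subsets, actual_outputs):
    """LDS for one sample of interest (hypothetical helper).

    scores[i]         -- attribution score of training example i
    subsets[j]        -- indices of the training examples in subset j
    actual_outputs[j] -- measured output of the model retrained on subset j
    """
    predicted = [sum(scores[i] for i in subset) for subset in subsets]
    return spearman(predicted, actual_outputs)


# Toy data: 4 training examples and 4 subsets (the benchmark above uses 64).
toy_scores = [0.5, -0.2, 0.1, 0.3]
toy_subsets = [[0, 1], [1, 2], [2, 3], [0, 3]]
toy_outputs = [1.0, 0.2, 1.5, 2.0]
toy_lds = lds_for_sample(toy_scores, toy_subsets, toy_outputs)
```

In this toy run the predicted and measured outputs are ordered identically, so the correlation is 1.0. In practice the per-sample scores would be averaged over all samples in the validation or generation set.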

Commands for counterfactual evaluation

  1. Data pre-processing

    Run this notebook first to get the indices of those training examples to be removed.

  2. Retrain models after removing the top-influential training examples

    bash scripts/run_counter.sh 0 18888 5000-0.5 0 59
  3. Generate images using the retrained models

    Run 02_counter.ipynb

  4. Measure L2 distance

    Run 03_counter_eval_l2.ipynb

  5. Measure CLIP cosine similarity

    Run 03_counter_eval_clip.ipynb
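
The two counterfactual metrics above can be sketched as follows, as a rough illustration rather than the notebooks' actual code. The sketch assumes each image (for the L2 distance) or each CLIP embedding (for the cosine similarity) has already been flattened into a list of floats; it does not load CLIP, and the function names are hypothetical.

```python
import math


def l2_distance(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    if norm_u == 0 or norm_v == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_u * norm_v)


# Placeholder vectors standing in for an image generated by the original
# model and the image generated by a retrained model from the same seed.
original = [0.0, 0.0]
retrained = [3.0, 4.0]
distance = l2_distance(original, retrained)
```

A larger L2 distance, or a lower CLIP cosine similarity, between the original and the regenerated image indicates that removing the selected training examples changed the generation more.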

Bibtex

If you find this project useful in your research, please consider citing our paper:

@inproceedings{zheng2023intriguing,
  title={Intriguing Properties of Data Attribution on Diffusion Models},
  author={Zheng, Xiaosen and Pang, Tianyu and Du, Chao and Jiang, Jing and Lin, Min},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024},
}