First push
andrewsong90 committed May 21, 2024
0 parents commit f4e0c01
Showing 306 changed files with 73,107 additions and 0 deletions.
13 changes: 13 additions & 0 deletions .gitignore
@@ -0,0 +1,13 @@
.DS_Store
*.pth
*.pth.tar
*.pt
**/__pycache__/
*bwh*.csv
.ipynb_checkpoints
.vscode
**/wandb
src/results/
**/embeddings/
**/prototypes/
**/test.ipynb
158 changes: 158 additions & 0 deletions README.md
@@ -0,0 +1,158 @@
# PANTHER


<b>Morphological Prototyping for Unsupervised Slide Representation Learning in Computational Pathology</b>, CVPR 2024.
<br><em>Andrew H. Song\*, Richard J. Chen\*, Tong Ding, Drew F.K. Williamson, Guillaume Jaume, Faisal Mahmood</em><br>


<img src="docs/panther.png" width="300px" align="right" />

[Paper Link] | [Cite](#cite)

**Abstract:** Representation learning of pathology whole-slide images (WSIs) has primarily relied on weak supervision with Multiple Instance Learning (MIL). However, the slide representations resulting from this approach are highly tailored to specific clinical tasks, which limits their expressivity and generalization, particularly in scenarios with limited data. Instead, we hypothesize that morphological redundancy in tissue can be leveraged to build a task-agnostic slide representation in an unsupervised fashion. To this end, we introduce **PANTHER**, a prototype-based approach rooted in the Gaussian mixture model that summarizes the set of WSI patches into a much smaller set of morphological prototypes. Specifically, each patch is assumed to have been generated from a mixture distribution, where each mixture component represents a morphological exemplar. Utilizing the estimated mixture parameters, we then construct a compact slide representation that can be readily used for a wide range of downstream tasks.
By performing an extensive evaluation of **PANTHER** on subtyping and survival tasks using 13 datasets, we show that 1) **PANTHER** outperforms or is on par with supervised MIL baselines and 2) the analysis of morphological prototypes brings new qualitative and quantitative insights into model interpretability.
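In mixture-model terms, each patch embedding $\mathbf{x}$ is assumed to be drawn from a standard Gaussian mixture density with one component per morphological prototype (a restatement of the abstract above, not notation taken from the paper):

```math
p(\mathbf{x}) = \sum_{c=1}^{C} \pi_c \, \mathcal{N}\left(\mathbf{x} \mid \boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c\right), \qquad \sum_{c=1}^{C} \pi_c = 1,
```

where $C$ is the number of prototypes and the estimated parameters $(\pi_c, \boldsymbol{\mu}_c, \boldsymbol{\Sigma}_c)$ form the compact slide representation.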


<img src='docs/fig_main.jpg' width="1400px" align="center"/>

## Updates
- 05/06/2024: The first version of the PANTHER codebase is now live!

## Installation
Please run the following command to create the PANTHER conda environment.
```shell
conda env create -f environment.yml
```
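Then activate the environment (it is named `panther` in `environment.yml`) before running any of the commands below:
```shell
conda activate panther
```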

## PANTHER Walkthrough
There are two workflows for PANTHER, depending on the use case.
- **Workflow 1**
  - Step 0 ⇒ Step 1 ⇒ Step 2A ⇒ (Step 3)
  - Useful for constructing *unsupervised* slide representations when no specific downstream task is at hand.
- **Workflow 2**
  - Step 0 ⇒ Step 1 ⇒ Step 2B ⇒ (Step 3)
  - Useful when there is a specific downstream task at hand; the slide representations are constructed as part of training the downstream model.

### Step 0. Dataset organization
**Data csv**: The data csv files (with appropriate splits, e.g., train, test) are placed within `src/splits` with the appropriate folder structure. For example, for a classification task on `ebrains`, we would have
```bash
splits/
├── ebrains
│   ├── train.csv
│   ├── val.csv
│   └── test.csv
```

Alternatively, for a 5-fold cross-validation survival task on TCGA-BRCA, we would have
```bash
splits/
├── TCGA_BRCA_survival_k=0
│   ├── train.csv
│   ├── val.csv
│   └── test.csv
├── ...
└── TCGA_BRCA_survival_k=4
    ├── train.csv
    ├── val.csv
    └── test.csv
```
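The exact CSV schema expected by the dataloaders is defined in `src`; as a purely illustrative sketch (every column name and value below is a hypothetical placeholder, not the verified format), a survival `train.csv` might look like:
```
slide_id,survival_months,censorship
SLIDE_001,48.2,0
SLIDE_002,12.7,1
```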

**Patch features**: For the following steps, we assume that features for each patch have already been extracted and that each WSI is represented as a set of patch features. For examples of patch feature extraction, please refer to [CLAM](https://github.com/mahmoodlab/CLAM).

The code assumes that the features are in either `.h5` or `.pt` format; the feature directory path `FEAT_DIR` has to end with the directory `feats_h5/` if the features are in `.h5` format, and with `feats_pt/` for `.pt` format.
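As a rough sketch of how such features can be inspected (the `features` key below follows the CLAM convention but is an assumption here; check your extraction pipeline for the actual key name):

```python
import h5py
import torch

# Load one slide's patch features; the expected shape is (num_patches, in_dim).
with h5py.File("FEAT_DIR/feats_h5/slide_id.h5", "r") as f:
    feats = f["features"][:]  # key name is an assumption (CLAM convention)

# Equivalent for the .pt layout.
feats_pt = torch.load("FEAT_DIR/feats_pt/slide_id.pt")
```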



### Step 1. Prototype construction
For prototype construction, we use K-means clustering across all training WSIs. We recommend GPU-based FAISS when clustering a large number of patch features. For example, we can use the following command to find 16 prototypes (each of dimension 1,024) using FAISS from the WSIs listed in `SPLIT_DIR/train.csv`.
```shell
CUDA_VISIBLE_DEVICES=0 python -m training.main_prototype \
    --mode faiss \
    --data_source FEAT_DIR_1,FEAT_DIR_2 \
    --split_dir SPLIT_DIR \
    --split_names train \
    --in_dim 1024 \
    --n_proto_patches 1000000 \
    --n_proto 16 \
    --n_init 5 \
    --seed 1 \
    --num_workers 10
```
The list of parameters is as follows:
- `mode`: Clustering backend, either `faiss` (GPU-enabled K-means) or `kmeans` (scikit-learn K-means on CPU).
- `in_dim`: Dimension of the patch features, dependent on the feature encoder.
- `n_proto`: Number of prototypes.
- `n_proto_patches`: Number of patch features to use per prototype. In total, `n_proto * n_proto_patches` features are used for finding prototypes.
- `n_init`: Number of K-means initializations to try.

The prototypes will be saved in the `SPLIT_DIR/prototypes` folder.
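Under the hood, this step amounts to GPU K-means over a large sample of patch features. A minimal FAISS sketch (an illustration under assumed array shapes, not the repository's exact code):

```python
import faiss
import numpy as np

d, k = 1024, 16  # in_dim, n_proto
# Placeholder for patch features stacked across training WSIs;
# the real run samples up to n_proto * n_proto_patches rows.
x = np.random.rand(100_000, d).astype(np.float32)

kmeans = faiss.Kmeans(d, k, niter=50, nredo=5, gpu=True, seed=1)
kmeans.train(x)
prototypes = kmeans.centroids  # (k, d) array of morphological prototypes
```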


A concrete script example using TCGA-BRCA patch features is shown below.
```shell
cd src
./scripts/prototype/brca.sh 0
```
This will initiate the script `scripts/prototype/clustering.sh` for K-means clustering. Detailed explanations of the clustering hyperparameters can be found in `clustering.sh`.

**Visualization**: See Step 3 for visualizing the prototypical assignment maps.

### Step 2A. Unsupervised slide representation construction
Once the prototypes are constructed, we can use **PANTHER** or **OT** to build unsupervised slide representations.
```shell
cd src
./scripts/embedding/brca.sh 0 panther
```
This step will create two files in the `SPLIT_DIR/embeddings` folder: 1) the original slide-level representation (**\*.pkl**) and 2) the slide-level representation tokenized into per-prototype statistics (**\*_tokenized.pkl**), e.g., mixture probability, mean, and covariance. Note that for **OT**, the mixture probability is uniform and `cov=None`.
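These embeddings can be loaded like any pickle file; the filename below is illustrative, and the object structure should be inspected rather than assumed:

```python
import pickle

# Load the tokenized slide representations produced by Step 2A.
with open("SPLIT_DIR/embeddings/run_tokenized.pkl", "rb") as f:  # illustrative path
    emb = pickle.load(f)

print(type(emb))  # inspect the structure before downstream use
```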

Alternatively, if you want to construct slide representations as part of a classification or survival downstream task, you can skip Step 2A and go straight to Step 2B.


### Step 2B. Training downstream model
Once the prototypes are constructed in Step 1 (Step 2A is not required), we can train a downstream model directly. For example, for survival prediction on TCGA-BRCA:
```shell
cd src
./scripts/survival/brca_surv.sh 0 panther
```

### Step 3. Visualization

To visualize GMM mixture proportions in prototypical assignment maps in PANTHER, see the accompanying [notebook](src/notebooks/prototypical_assignment_map_visualization.ipynb).

<img src='docs/prototypical_assignment_map.jpg' width="1400px" align="center"/>

## Additional Findings in PANTHER
- We observe that using high-quality pretrained ROI encoders (such as UNI) leads to significant performance gains across all MIL and set-based learning methods (see the **Supplement**). Specifically:
  - When using ResNet-50 (ImageNet transfer) and CTransPath features, unsupervised set representation methods such as OT and PANTHER underperform MIL methods (using the same features). With UNI features, OT and PANTHER can readily outperform MIL, and should be considered strong baselines when evaluating slide-level tasks.
  - DeepAttnMISL with UNI features becomes a strong MIL baseline. This can be attributed to DeepAttnMISL's dependence on K-means for cluster pooling, which in turn depends on high-quality representations.
  - With unsupervised slide representations extracted per WSI (via OT or PANTHER), training survival models on WSIs is much more stable, since you can directly use the Cox loss (instead of NLL); a minimal sketch of such a loss is given after this list. Across all of our ablation experiments, PANTHER with UNI features always achieved C-Index > 0.6.
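For reference, a minimal PyTorch sketch of a Cox partial-likelihood loss (a standard formulation, not the repository's implementation):

```python
import torch

def cox_partial_likelihood_loss(risk: torch.Tensor,
                                time: torch.Tensor,
                                event: torch.Tensor) -> torch.Tensor:
    """Negative Cox partial log-likelihood.

    risk:  (N,) predicted log-risk scores (one per slide representation)
    time:  (N,) observed survival or censoring times
    event: (N,) 1.0 if the event was observed, 0.0 if censored
    """
    order = torch.argsort(time, descending=True)    # sort so each risk set is a prefix
    risk, event = risk[order], event[order]
    log_risk_set = torch.logcumsumexp(risk, dim=0)  # log sum_{j: t_j >= t_i} exp(risk_j)
    return -((risk - log_risk_set) * event).sum() / event.sum().clamp(min=1.0)
```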

## PANTHER Limitations
As unsupervised slide representations in PANTHER are created using non-parametric techniques such as K-means clustering and GMMs (which rely on Euclidean distance or dot product to compare embeddings), we note the following limitations:
- Depending on the degree of dataset shift between the train and test distributions (e.g., H&E stain variability, known as image acquisition shift), prototype assignment for certain WSIs may collapse, with all patches assigned to a single prototype. This is exemplified in TCGA, which has site-specific biases, and is thus an important consideration when using PANTHER (or any non-parametric approach) for histopathologic biomarker discovery.
- When clustering over a WSI dataset composed of millions to billions of patches, clustering with only `C=16` clusters will likely underfit the dataset and can likewise cause all patches in a WSI to collapse onto a single prototype. Empirically, we found `C=16` to outperform `C=32` in supervised settings; however, in settings such as biomarker discovery or unsupervised tissue segmentation, using more prototypes may improve performance.
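As a quick sanity check for this collapse failure mode, one could inspect the per-slide assignment histogram. The sketch below is a hypothetical diagnostic (not part of the codebase), assuming the patch features and Step 1 prototypes are available as NumPy arrays:

```python
import numpy as np

def prototype_assignment_fractions(feats: np.ndarray, prototypes: np.ndarray) -> np.ndarray:
    """Fraction of a WSI's patches assigned to each prototype.

    feats:      (num_patches, d) patch features for one slide
    prototypes: (C, d) prototype vectors from Step 1
    """
    # Squared Euclidean distances via ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2,
    # which avoids materializing a (num_patches, C, d) tensor.
    d2 = ((feats ** 2).sum(1, keepdims=True)
          - 2.0 * feats @ prototypes.T
          + (prototypes ** 2).sum(1))
    assign = d2.argmin(axis=1)
    return np.bincount(assign, minlength=len(prototypes)) / len(feats)

# Example: flag likely collapse when >95% of patches hit a single prototype.
# frac = prototype_assignment_fractions(feats, prototypes)
# collapsed = frac.max() > 0.95
```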

## Cite
If you find our work useful in your research or if you use parts of this code, please cite our paper:

```bibtex
@inproceedings{song2024morphological,
  title={Morphological Prototyping for Unsupervised Slide Representation Learning in Computational Pathology},
  author={Song, Andrew H and Chen, Richard J and Ding, Tong and Williamson, Drew FK and Jaume, Guillaume and Mahmood, Faisal},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}
```

## Acknowledgements
The code for **PANTHER** was adapted and inspired by the fantastic works of [DIEM](https://openreview.net/forum?id=MXdFBmHT4C), [OTK](https://openreview.net/forum?id=ZK6vTvb84s), and [CLAM](https://github.com/mahmoodlab/CLAM). Boilerplate code for setting up supervised MIL benchmarks was developed by Ming Y. Lu.

## Issues
- Please open new issue threads, or (for urgent blockers) report directly to `[email protected]`.
- Immediate responses to minor issues may not be available.

## License and Usage
[Mahmood Lab](https://faisal.ai) - This code is made available under the CC-BY-NC-ND 4.0 License and is available for non-commercial academic purposes.

<img src="docs/joint_logo.png">
Binary file added docs/fig_main.jpg
Binary file added docs/joint_logo.png
Binary file added docs/panther.png
Binary file added docs/prototypical_assignment_map.jpg
150 changes: 150 additions & 0 deletions environment.yml
@@ -0,0 +1,150 @@
name: panther
channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- asttokens=2.4.1=pyhd8ed1ab_0
- blas=1.0=openblas
- bzip2=1.0.8=h5eee18b_5
- ca-certificates=2024.2.2=hbcca054_0
- comm=0.2.2=pyhd8ed1ab_0
- cudatoolkit=11.4.1=h8ab8bb3_9
- debugpy=1.6.7=py310h6a678d5_0
- decorator=5.1.1=pyhd8ed1ab_0
- entrypoints=0.4=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- executing=2.0.1=pyhd8ed1ab_0
- faiss-gpu=1.7.4=py3.10_hc0239a3_0_cuda11.4
- ipykernel=6.29.3=pyhd33586a_0
- ipython=8.22.2=pyh707e725_0
- jedi=0.19.1=pyhd8ed1ab_0
- jupyter_client=7.3.4=pyhd8ed1ab_0
- jupyter_core=5.7.2=py310hff52083_0
- ld_impl_linux-64=2.38=h1181459_1
- libfaiss=1.7.4=h13c3c6d_0_cuda11.4
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libopenblas=0.3.21=h043d6bf_0
- libsodium=1.0.18=h36c2ea0_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- numpy=1.26.4=py310heeff2f4_0
- numpy-base=1.26.4=py310h8a23956_0
- openssl=3.0.13=h7f8727e_0
- packaging=24.0=pyhd8ed1ab_0
- parso=0.8.4=pyhd8ed1ab_0
- pexpect=4.9.0=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pip=23.3.1=py310h06a4308_0
- platformdirs=4.2.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.42=pyha770c72_0
- psutil=5.9.0=py310h5eee18b_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pygments=2.17.2=pyhd8ed1ab_0
- python=3.10.14=h955ad1f_0
- python-dateutil=2.9.0=pyhd8ed1ab_0
- python_abi=3.10=2_cp310
- pyzmq=25.1.2=py310h6a678d5_0
- readline=8.2=h5eee18b_0
- setuptools=68.2.2=py310h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.41.2=h5eee18b_0
- stack_data=0.6.2=pyhd8ed1ab_0
- tk=8.6.12=h1ccaba5_0
- tornado=6.1=py310h5764c6d_3
- traitlets=5.14.3=pyhd8ed1ab_0
- typing_extensions=4.11.0=pyha770c72_0
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.41.2=py310h06a4308_0
- xz=5.4.6=h5eee18b_0
- zeromq=4.3.5=h6a678d5_0
- zlib=1.2.13=h5eee18b_0
- pip:
  - absl-py==2.1.0
  - appdirs==1.4.4
  - certifi==2024.2.2
  - charset-normalizer==3.3.2
  - click==8.1.7
  - contourpy==1.2.1
  - cycler==0.12.1
  - docker-pycreds==0.4.0
  - ecos==2.0.13
  - einops==0.7.0
  - filelock==3.13.4
  - fonttools==4.51.0
  - fsspec==2024.3.1
  - gitdb==4.0.11
  - gitpython==3.1.43
  - grpcio==1.62.2
  - h5py==3.11.0
  - huggingface-hub==0.22.2
  - idna==3.7
  - intel-openmp==2024.1.0
  - jinja2==3.1.3
  - joblib==1.4.0
  - kiwisolver==1.4.5
  - markdown==3.6
  - markupsafe==2.1.5
  - matplotlib==3.8.4
  - mkl==2024.1.0
  - mpmath==1.3.0
  - networkx==3.3
  - numexpr==2.10.0
  - nvidia-cublas-cu12==12.1.3.1
  - nvidia-cuda-cupti-cu12==12.1.105
  - nvidia-cuda-nvrtc-cu12==12.1.105
  - nvidia-cuda-runtime-cu12==12.1.105
  - nvidia-cudnn-cu12==8.9.2.26
  - nvidia-cufft-cu12==11.0.2.54
  - nvidia-curand-cu12==10.3.2.106
  - nvidia-cusolver-cu12==11.4.5.107
  - nvidia-cusparse-cu12==12.1.0.106
  - nvidia-nccl-cu12==2.19.3
  - nvidia-nvjitlink-cu12==12.4.127
  - nvidia-nvtx-cu12==12.1.105
  - nystrom-attention==0.0.12
  - osqp==0.6.5
  - pandas==2.2.2
  - pillow==10.3.0
  - protobuf==4.25.3
  - pyparsing==3.1.2
  - pytz==2024.1
  - pyyaml==6.0.1
  - qdldl==0.1.7.post2
  - regex==2024.4.16
  - requests==2.31.0
  - safetensors==0.4.3
  - scikit-learn==1.3.2
  - scikit-survival==0.22.2
  - scipy==1.11.4
  - seaborn==0.13.2
  - sentry-sdk==1.45.0
  - setproctitle==1.3.3
  - smmap==5.0.1
  - sympy==1.12
  - tbb==2021.12.0
  - tensorboard==2.16.2
  - tensorboard-data-server==0.7.2
  - threadpoolctl==3.4.0
  - tokenizers==0.19.1
  - torch==2.2.2
  - torchvision==0.17.2
  - tqdm==4.66.2
  - transformers==4.40.0
  - triton==2.2.0
  - tzdata==2024.1
  - urllib3==2.2.1
  - wandb==0.16.6
  - werkzeug==3.0.2
prefix: /home/andrew/anaconda3/envs/panther
Empty file added src/__init__.py
9 changes: 9 additions & 0 deletions src/configs/ABMIL_default/config.json
@@ -0,0 +1,9 @@
{
    "gate": true,
    "in_dim": 768,
    "n_classes": 2,
    "embed_dim": 512,
    "attn_dim": 384,
    "n_fc_layers": 1,
    "dropout": 0.25
}
9 changes: 9 additions & 0 deletions src/configs/ABMIL_tiny/config.json
@@ -0,0 +1,9 @@
{
    "gate": true,
    "in_dim": 64,
    "n_classes": 2,
    "embed_dim": 64,
    "attn_dim": 64,
    "n_fc_layers": 1,
    "dropout": 0.25
}
8 changes: 8 additions & 0 deletions src/configs/H2T_default/config.json
@@ -0,0 +1,8 @@
{
    "in_dim": 768,
    "n_classes": 2,
    "out_size": 8,
    "load_proto": false,
    "proto_path": ".",
    "fix_proto": false
}
15 changes: 15 additions & 0 deletions src/configs/IndivMLPEmb_Indiv/config.json
@@ -0,0 +1,15 @@
{
    "in_dim": 2049,
    "n_classes": 4,
    "shared_embed_dim": 256,
    "indiv_embed_dim": 128,
    "postcat_embed_dim": 1024,
    "shared_mlp": false,
    "indiv_mlps": true,
    "postcat_mlp": false,
    "n_fc_layers": 1,
    "shared_dropout": 0.1,
    "indiv_dropout": 0.1,
    "postcat_dropout": 0.1,
    "p": 32
}
15 changes: 15 additions & 0 deletions src/configs/IndivMLPEmb_IndivPost/config.json
@@ -0,0 +1,15 @@
{
    "in_dim": 2049,
    "n_classes": 4,
    "shared_embed_dim": 256,
    "indiv_embed_dim": 128,
    "postcat_embed_dim": 1024,
    "shared_mlp": false,
    "indiv_mlps": true,
    "postcat_mlp": true,
    "n_fc_layers": 1,
    "shared_dropout": 0.1,
    "indiv_dropout": 0.1,
    "postcat_dropout": 0.1,
    "p": 32
}