ScribblePrompt

Official implementation of ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image, accepted at ECCV 2024.

Hallee E. Wong, Marianne Rakic, John Guttag, Adrian V. Dalca

Updates

  • (2024-12-31) Released example training code
  • (2024-12-12) Released full prompt simulation code
  • (2024-07-01) ScribblePrompt has been accepted to ECCV 2024!
  • (2024-06-17) ScribblePrompt won the Bench-to-Bedside Award at the CVPR 2024 DCAMI Workshop!
  • (2024-04-16) Released MedScribble -- a diverse dataset of segmentation tasks with scribble annotations
  • (2024-04-15) An updated version of the paper is on arXiv!
  • (2024-04-14) Added Google Colab Tutorial
  • (2024-01-19) Released scribble simulation code
  • (2023-12-15) Released model code and weights
  • (2023-12-12) Paper and online demo released

Overview

ScribblePrompt is an interactive segmentation tool that enables users to segment unseen structures in medical images using scribbles, clicks, and bounding boxes.

Try ScribblePrompt

Models

We provide checkpoints for two versions of ScribblePrompt:

  • ScribblePrompt-UNet with an efficient fully-convolutional architecture

  • ScribblePrompt-SAM based on the Segment Anything Model

Both models were trained with iterative scribble, click, and bounding box interactions on a diverse collection of 65 medical imaging datasets with both real and synthetic labels.

MedScribble Dataset

We release MedScribble, a dataset of multi-annotator scribble annotations for diverse biomedical image segmentation tasks, in ./MedScribble. See the README there for more info and ./MedScribble/tutorial.ipynb for a preview of the data.

Installation

You can install scribbleprompt in two ways:

  • With pip:
pip install git+https://github.com/halleewong/ScribblePrompt.git
  • Manually: clone the repository and install its dependencies
git clone https://github.com/halleewong/ScribblePrompt
python -m pip install -r ./ScribblePrompt/requirements.txt
export PYTHONPATH="$PYTHONPATH:$(realpath ./ScribblePrompt)"

The following dependency is optional and only needed to run the demo locally:

pip install gradio==3.40.1

Getting Started

First, download the model checkpoints to ./checkpoints.

To run an interactive demo locally:

python demos/app.py

To instantiate ScribblePrompt-UNet and make a prediction:

from scribbleprompt import ScribblePromptUNet

sp_unet = ScribblePromptUNet()

mask = sp_unet.predict(
    image,        # (B, 1, H, W) 
    point_coords, # (B, n, 2)
    point_labels, # (B, n)
    scribbles,    # (B, 2, H, W)
    box,          # (B, n, 4)
    mask_input,   # (B, 1, H, W)
) # -> (B, 1, H, W) 
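For concreteness, here is a minimal sketch of a single-click prediction. The random image, the click coordinates, and the assumption that unused prompt types can be passed as None are illustrative, not guaranteed by the API:

import torch
from scribbleprompt import ScribblePromptUNet

sp_unet = ScribblePromptUNet()

# Stand-in for a real preprocessed image: (B, 1, H, W), values in [0, 1].
image = torch.rand(1, 1, 128, 128)

# One positive (foreground) click at pixel (x=64, y=64).
point_coords = torch.tensor([[[64, 64]]])  # (B, n, 2)
point_labels = torch.tensor([[1]])         # (B, n); assumed 1 = foreground, 0 = background

mask = sp_unet.predict(
    image,
    point_coords,
    point_labels,
    None,  # no scribbles
    None,  # no bounding box
    None,  # no previous prediction
)  # (B, 1, H, W)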

To instantiate ScribblePrompt-SAM and make a prediction:

from scribbleprompt import ScribblePromptSAM

sp_sam = ScribblePromptSAM()

mask, img_features, low_res_logits = sp_sam.predict(
    image,        # (B, 1, H, W) 
    point_coords, # (B, n, 2)
    point_labels, # (B, n)
    scribbles,    # (B, 2, H, W)
    box,          # (B, n, 4)
    mask_input,   # (B, 1, 256, 256)
) # -> (B, 1, H, W), (B, 16, 256, 256), (B, 1, 256, 256)

For best results, image should have spatial dimensions (H, W) = (128, 128) and pixel values min-max normalized to the [0, 1] range.
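A minimal preprocessing sketch, assuming the input is a grayscale image stored as a 2D torch tensor (this helper is illustrative, not part of the package):

import torch
import torch.nn.functional as F

def preprocess(img: torch.Tensor) -> torch.Tensor:
    # (H, W) -> (1, 1, 128, 128)
    img = img.float()[None, None]
    img = F.interpolate(img, size=(128, 128), mode="bilinear", align_corners=False)
    # Min-max normalize to [0, 1]; the epsilon guards against constant images.
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    return img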

For ScribblePrompt-UNet, mask_input should be the logits from the previous prediction. For ScribblePrompt-SAM, mask_input should be low_res_logits from the previous prediction.
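For example, a two-round refinement with ScribblePrompt-SAM might look like the sketch below. The variables reuse the names from the snippet above; in practice the second round's prompts would come from new user corrections:

# Round 1: initial prediction from a bounding box only.
mask, img_features, low_res_logits = sp_sam.predict(
    image, None, None, None, box, None
)

# Round 2: refine with an additional click, feeding back the previous
# low-resolution logits as mask_input.
mask, img_features, low_res_logits = sp_sam.predict(
    image, point_coords, point_labels, None, box, low_res_logits
)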

Training

Note: our training code requires the pylot library; the inference code above does not. We recommend installing it via pip:

pip install git+https://github.com/JJGO/pylot.git@87191921033c4391546fd88c5f963ccab7597995

Training is configured through YAML config files. We provide two example configs in ./configs: one for fine-tuning from the pre-trained ScribblePrompt-UNet weights and one for training from scratch on an example dataset.

To fine-tune ScribblePrompt-UNet from the pre-trained weights:

python scribbleprompt/experiment/unet.py -config finetune_unet.yaml 

To train a model from scratch:

python scribbleprompt/experiment/unet.py -config train_unet.yaml 

For a more in-depth tutorial see ./notebooks/training.ipynb.

To Do

  • [x] Release Gradio demo
  • [x] Release model code and weights
  • [x] Release jupyter notebook tutorial
  • [x] Release scribble simulation code
  • [x] Release MedScribble dataset
  • [x] Release training code
  • [ ] Release segmentation labels collected using ScribblePrompt

Acknowledgements

  • Our training code builds on the pylot library for deep learning experiment management. We also make use of data augmentation code originally developed for UniverSeg. Thanks to @JJGO for sharing this code!

  • We use functions from voxsynth for applying random deformations during scribble simulation.

  • Code for ScribblePrompt-SAM builds on Segment Anything. Thanks to Meta AI for open-sourcing the model.

Citation

If you find our work or any of our materials useful, please cite our paper:

@inproceedings{wong2024scribbleprompt,
  title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image},
  author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2024},
}

License

Code for this project is released under the Apache 2.0 License.
