Pre-trained Natural Image Models are Few-shot Learners for Medical Image Classification: COVID-19 Diagnosis as an Example
This repository contains the code and datasets that help users reproduce our reported results.
- Demo of Visual Reconstruction by MAE in Medical Images
- Pre-training Datasets
- Fine-tuning Datasets
- Transitional training scheme
- Pre-trained models for fine-tuning
- Intermediate models for fine-tuning
- Pre-training recipes
- Fine-tuning recipes
As shown in the figure, the transitional training scheme comprises two phases. In phase-1, we fine-tune a pre-trained natural image model to produce an intermediate model. In phase-2 (depicted within the yellow frame), this intermediate model is further fine-tuned on the target few-shot dataset U_orig to yield the final few-shot COVID-19 diagnosis model.
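For concreteness, below is a minimal sketch of the two phases in PyTorch with timm. The loader names (`cxc_loader`, `u_orig_loader`), hyperparameters, and checkpoint filename are illustrative placeholders rather than this repository's actual identifiers; the real recipes are in FINETUNE.md.

```python
import timm
import torch

def finetune(model, loader, epochs, lr, device="cuda"):
    """Generic full fine-tuning loop, shared by both phases."""
    model = model.to(device).train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
    return model

# cxc_loader / u_orig_loader: torch.utils.data.DataLoader objects over the
# CXC and U_orig datasets (placeholders, defined elsewhere).

# Phase 1: fine-tune a natural-image pre-trained backbone on the large
# medical dataset CXC, yielding an intermediate model.
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=2)
intermediate = finetune(model, cxc_loader, epochs=50, lr=1e-4)
torch.save(intermediate.state_dict(), "vit_b16_in1k_cxc.pth")  # hypothetical name

# Phase 2: fine-tune the intermediate model on the few-shot target set U_orig.
final_model = finetune(intermediate, u_orig_loader, epochs=50, lr=1e-5)
```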
Name | Abbreviation | COVID | Normal | Bacteria | Dataset size |
---|---|---|---|---|---|
chest-ct-scans-with-COVID-19 | CHE | 27,781 | 0 | 0 | 27,781 |
COVID-19_ct_scans1 | CCT | 1,762 | 0 | 0 | 1,762 |
COVID-19-20_v2 | C1920 | 6,723 | 0 | 0 | 6,723 |
COVID-19-AR | CAR | 18,592 | 0 | 0 | 18,592 |
COVID-19-CT-segmentation-dataset | CCS | 110 | 0 | 0 | 110 |
COVID19-CT-Dataset1000+ | C1000 | 307,765 | 0 | 0 | 307,765 |
CT_Images_in_COVID-19 | CIC | 32,548 | 0 | 0 | 32,548 |
MIDRC-RICORD-1A | MRA | 9,833 | 0 | 0 | 9,833 |
MIDRC-RICORD-1B | MRB | 5,501 | 0 | 0 | 5,501 |
sarscov2-ctscan-dataset | SC | 1,252 | 1,229 | 0 | 2,481 |
SIRM-COVID-19 (data since removed by the host) | SIRM | 599 | 0 | 0 | 599 |
COVIDX-CT-2A | CXC | 93,975 | 59,510 | 0 | 153,485 |
large-COVID-19-ct-slice-dataset | LC | 7,593 | 6,893 | 0 | 14,486 |
COVID-19-and-common-pneumonia-chest-CT-dataset | CC | 41,813 | 0 | 55,219 | 97,032 |
Total | / | 555,847 | 67,632 | 55,219 | 678,698 |
The downstream fine-tuning dataset is sourced from the work COVID-CT-Dataset: A CT Scan Dataset about COVID-19, which established a real-world COVID-19 CT image dataset, UCSD-AI4H-COVID-CT (abbreviated U_orig in our work, where ‘orig’ means ‘original’). This dataset is relatively popular in the deep learning-based medical image analysis community, especially in work related to COVID-19 diagnosis.
Note that U_orig is not curated at the patient level: each patient contributes multiple adjacent CT images (slices) from a single CT scan. For a more stringent performance evaluation in our experiments, we select only one slice from each patient, creating a smaller dataset named U_sani (‘sani’ means ‘sanitized’). U_sani consists of 131 positive samples and 158 negative samples; a sketch of this sanitization step follows the table below.
Name | COVID samples | COVID patients | Non-COVID samples | Non-COVID patients |
---|---|---|---|---|
U_orig | 349 | 216 | 397 | 171 |
U_sani | 131 | 131 | 158 | 158 |
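The sanitization step can be sketched as follows; the record layout (`(patient_id, slice_path, label)` tuples) is an assumption for illustration, not the repository's actual data format.

```python
from collections import OrderedDict

def sanitize(records):
    """records: iterable of (patient_id, slice_path, label) tuples.
    Returns one record per patient (the first slice encountered),
    so that no patient contributes more than one image."""
    per_patient = OrderedDict()
    for patient_id, slice_path, label in records:
        # Keep only the first slice seen for each patient.
        per_patient.setdefault(patient_id, (slice_path, label))
    return [(pid, path, label) for pid, (path, label) in per_patient.items()]
```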
Pre-trained model | Backbone | Training dataset | Training method | Image domain |
---|---|---|---|---|
ViT-B/16_IN1K | ViT-B/16 | IN1K | SL | Natural image |
ViT-L/16_IN1K | ViT-L/16 | IN1K | SL | Natural image |
ViT-B/16_CXC | ViT-B/16 | CXC | SL | Medical image |
ViT-L/16_CXC | ViT-L/16 | CXC | SL | Medical image |
MAE-B/16_IN1K | ViT-B/16 | IN1K | SSL | Natural image |
MAE-L/16_IN1K | ViT-L/16 | IN1K | SSL | Natural image |
MAE-B/16_CXC | ViT-B/16 | CXC | SSL | Medical image |
MAE-L/16_CXC | ViT-L/16 | CXC | SSL | Medical image |
MAE-B/16_DATA13 | ViT-B/16 | DATA13 | SSL | Medical image |
MAE-L/16_DATA13 | ViT-L/16 | DATA13 | SSL | Medical image |
MAE-B/16_DATA14 | ViT-B/16 | DATA14 | SSL | Medical image |
MAE-L/16_DATA14 | ViT-L/16 | DATA14 | SSL | Medical image |
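To fine-tune from one of the MAE checkpoints above, the encoder weights are loaded into a plain ViT classifier. Below is a minimal sketch assuming the checkpoint follows the official MAE release convention (weights stored under a `'model'` key); the filename is hypothetical.

```python
import timm
import torch

# Build a ViT-B/16 classifier with a fresh 2-class head.
model = timm.create_model("vit_base_patch16_224", num_classes=2)

checkpoint = torch.load("mae_b16_in1k.pth", map_location="cpu")  # hypothetical path
state_dict = checkpoint.get("model", checkpoint)

# The MAE decoder and mask token are not part of the classifier, and the
# classification head is trained from scratch, so a non-strict load is used.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("newly initialised:", missing)   # e.g. the classification head
print("discarded:", unexpected)        # e.g. MAE decoder weights
```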
Intermediate model | Backbone | Base model | Dataset for Phase-1 |
---|---|---|---|
ViT-B/16_IN1K/CXC | ViT-B/16 | ViT-B/16_IN1K | CXC |
ViT-L/16_IN1K/CXC | ViT-L/16 | ViT-L/16_IN1K | CXC |
MAE-B/16_IN1K/CXC | ViT-B/16 | MAE-B/16_IN1K | CXC |
MAE-L/16_IN1K/CXC | ViT-L/16 | MAE-L/16_IN1K | CXC |
MAE-B/16_DATA13/CXC | ViT-B/16 | MAE-B/16_DATA13 | CXC |
MAE-L/16_DATA13/CXC | ViT-L/16 | MAE-L/16_DATA13 | CXC |
The pre-training recipes are in PRETRAIN.md.
The fine-tuning recipes are in FINETUNE.md.
Results of training from scratch using U_orig across 12 different random seeds and 4 different data splits (5:4:1, 6:3:1, 7:2:1 and 8:1:1). The average and standard deviation of performance scores across 48 (12×4) experimental trials with ViT-B/16 and ViT-L/16 are reported.
Backbone | ViT-B/16 | ViT-L/16 |
---|---|---|
Accuracy | 0.5853±0.0325 | 0.5897±0.0321 |
F1 | 0.4827±0.0870 | 0.4988±0.0895 |
AUC | 0.6001±0.0245 | 0.6132±0.0266 |
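For reference, here is a minimal sketch of how such mean±std scores can be aggregated with scikit-learn; `trial_results` is an illustrative container of per-trial predictions, not a name from this repository.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

def summarise(trial_results):
    """trial_results: iterable of (y_true, y_pred, y_score) per trial."""
    acc, f1, auc = [], [], []
    for y_true, y_pred, y_score in trial_results:
        acc.append(accuracy_score(y_true, y_pred))
        f1.append(f1_score(y_true, y_pred))
        auc.append(roc_auc_score(y_true, y_score))
    for name, scores in [("Accuracy", acc), ("F1", f1), ("AUC", auc)]:
        print(f"{name}: {np.mean(scores):.4f}±{np.std(scores):.4f}")
```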
Results of full fine-tuning of pre-trained ViT models and MAE models using U_orig across 12 different random seeds, with data split = 2:3:5. Self-supervised MAE models consistently outperform supervised ViT models, both on IN1K and CXC. Moreover, within the MAE models, those pre-trained on natural images (IN1K) demonstrate better performance than those trained on medical images (CXC), likely due to the larger size and greater diversity of IN1K.
Pre-trained Model | Dataset | Accuracy | F1 | AUC | Type |
---|---|---|---|---|---|
ViT-B/16_IN1K | U_orig | 0.7580±0.0180 | 0.7300±0.0248 | 0.8193±0.0175 | SL |
ViT-L/16_IN1K | U_orig | 0.7689±0.0206 | 0.7432±0.0290 | 0.8354±0.0243 | SL |
ViT-B/16_CXC | U_orig | 0.7052±0.0124 | 0.6741±0.0321 | 0.7825±0.0175 | SL |
ViT-L/16_CXC | U_orig | 0.7014±0.0193 | 0.6701±0.0389 | 0.7731±0.0146 | SL |
MAE-B/16_IN1K | U_orig | 0.8119±0.0399 | 0.7893±0.0615 | 0.9054±0.0297 | SSL |
MAE-L/16_IN1K | U_orig | 0.8144±0.0314 | 0.7850±0.0542 | 0.9084±0.0171 | SSL |
MAE-B/16_CXC | U_orig | 0.7947±0.0383 | 0.7809±0.0481 | 0.8794±0.0343 | SSL |
MAE-L/16_CXC | U_orig | 0.7213±0.0821 | 0.6775±0.1567 | 0.8060±0.0714 | SSL |
Results of full fine-tuning of pre-trained medical image models using U_orig across 12 different random seeds, with data split = 2:3:5. The models pre-trained on DATA14 consistently outperform those pre-trained on DATA13. Moreover, larger models need a larger quantity of pre-training data to enhance overall performance and stability.
Pre-trained Model | Dataset | Accuracy | F1 | AUC | Type |
---|---|---|---|---|---|
MAE-B/16_DATA13 | U_orig | 0.7999±0.0669 | 0.7773±0.1002 | 0.8869±0.0596 | SSL |
MAE-L/16_DATA13 | U_orig | 0.7288±0.1253 | 0.5803±0.3290 | 0.8300±0.0928 | SSL |
MAE-B/16_DATA14 | U_orig | 0.8004±0.0719 | 0.7938±0.0708 | 0.8795±0.0646 | SSL |
MAE-L/16_DATA14 | U_orig | 0.7469±0.1370 | 0.5983±0.3657 | 0.8465±0.0922 | SSL |
By fine-tuning the intermediate models, we achieve the best performance in few-shot real-world COVID-19 classification tasks from CT images (detailed in the paper).
The following table provides the results of full fine-tuning of intermediate models using U_orig across 12 different random seeds, with data split = 2:3:5. Intermediate models consistently demonstrate a remarkable performance gain compared to the corresponding pre-trained base models. Moreover, intermediate MAE models outperform intermediate ViT models significantly. Notably, MAE-L/16_IN1K/CXC performs the best among all the intermediate models.
Intermediate model | Dataset | Accuracy | F1 | AUC |
---|---|---|---|---|
ViT-B/16_IN1K/CXC | U_orig | 0.7712 ± 0.0188 | 0.7464 ± 0.0206 | 0.8456 ± 0.0114 |
ViT-L/16_IN1K/CXC | U_orig | 0.7718 ± 0.0172 | 0.7453 ± 0.0313 | 0.8494 ± 0.0159 |
MAE-B/16_IN1K/CXC | U_orig | 0.8554 ± 0.0222 | 0.8445 ± 0.0281 | 0.9337 ± 0.0113 |
MAE-L/16_IN1K/CXC | U_orig | 0.8680 ± 0.0157 | 0.8586 ± 0.0164 | 0.9380 ± 0.0125 |
MAE-B/16_DATA13/CXC | U_orig | 0.8434 ± 0.0231 | 0.8319 ± 0.0287 | 0.9258 ± 0.0117 |
MAE-L/16_DATA13/CXC | U_orig | 0.8385 ± 0.0255 | 0.8355 ± 0.0242 | 0.9217 ± 0.0111 |
Comparison with previously published works on U_orig; each pair of rows compares our model against one prior work under that work's evaluation setting.
Approach | Accuracy | F1 | AUC |
---|---|---|---|
Ours | 0.9026 | 0.8914 | 0.9689 |
Work 1 | 0.8910 | 0.8960 | 0.9810 |
Ours | 0.9113 | 0.9032 | 0.9514 |
Work 2 | 0.8600 | 0.8500 | 0.9400 |
Results of full fine-tuning of different MAE models on U_orig and U_sani across 46 different data splits, with random seed = 42. Despite the rigorous evaluation conditions and the scarcity of samples in ultra few-shot cases, intermediate models consistently outperform pre-trained base models. Notably, MAE-L/16_IN1K/CXC performs the best on both U_orig and U_sani under these challenging conditions.
Pre-trained/Intermediate Model | Dataset | Accuracy | F1 | AUC |
---|---|---|---|---|
MAE-B/16_IN1K | U_orig | 0.8306±0.0518 | 0.8170±0.0547 | 0.9209±0.0358 |
MAE-B/16_IN1K | U_sani | 0.7589±0.0647 | 0.7458±0.0562 | 0.8595±0.0505 |
MAE-L/16_IN1K | U_orig | 0.8359±0.0433 | 0.8156±0.0588 | 0.9213±0.0348 |
MAE-L/16_IN1K | U_sani | 0.7528±0.0528 | 0.7128±0.1302 | 0.8546±0.0444 |
MAE-B/16_DATA14 | U_orig | 0.7985±0.0915 | 0.7781±0.0987 | 0.8750±0.1033 |
MAE-B/16_DATA14 | U_sani | 0.6928±0.0628 | 0.6543±0.1421 | 0.7863±0.0576 |
MAE-L/16_DATA14 | U_orig | 0.7130±0.1056 | 0.6271±0.2648 | 0.7958±0.0999 |
MAE-L/16_DATA14 | U_sani | 0.6426±0.0745 | 0.5190±0.2916 | 0.7551±0.1081 |
MAE-B/16_IN1K/CXC | U_orig | 0.8635±0.0291 | 0.8539±0.0295 | 0.9406±0.0274 |
MAE-B/16_IN1K/CXC | U_sani | 0.8165±0.0578 | 0.8070±0.0437 | 0.8928±0.0462 |
MAE-L/16_IN1K/CXC | U_orig | 0.8723±0.0317 | 0.8649±0.0272 | 0.9384±0.0265 |
MAE-L/16_IN1K/CXC | U_sani | 0.8346±0.0754 | 0.8290±0.0733 | 0.8949±0.0831 |
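A minimal sketch of how fixed-seed stratified train/val/test splits like those above can be generated with scikit-learn; the ratio argument shown is illustrative, and the actual 46 split configurations follow the paper.

```python
from sklearn.model_selection import train_test_split

def make_split(samples, labels, ratios, seed=42):
    """Stratified train/val/test split, e.g. ratios=(0.2, 0.3, 0.5)."""
    train_r, val_r, test_r = ratios
    # First carve out the training portion, stratified by label.
    x_train, x_rest, y_train, y_rest = train_test_split(
        samples, labels, train_size=train_r, stratify=labels, random_state=seed)
    # Then split the remainder into validation and test.
    x_val, x_test, y_val, y_test = train_test_split(
        x_rest, y_rest, train_size=val_r / (val_r + test_r),
        stratify=y_rest, random_state=seed)
    return (x_train, y_train), (x_val, y_val), (x_test, y_test)
```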
- This project is under the CC-BY-NC 4.0 license. See LICENSE for details.