A multi-branch deep learning model for ECG-based binary disease classification, originally developed for Pulmonary Embolism (PE) detection.
Each branch independently processes a different group of ECG leads — converted to time-frequency scalogram images via Continuous Wavelet Transform (CWT) — and the three feature streams are fused into a single classification head.
Note: Training data is not included in this repository due to patient privacy constraints. See Data format for the expected input structure.
ECG Signal (12-lead + Long Lead II, ≥13 channels)
│
├── V1–V6 ───────────────────────────► CWT Scalogram ──► ResNet-18 (Arm 1) ──► 512-dim
│                                                                                 │
├── I / II / III / aVR / aVL / aVF ──► CWT Scalogram ──► ResNet-18 (Arm 2) ──► 512-dim
│                                                                                 │
└── Long Lead II ────────────────────► CWT Scalogram ──► ResNet-18 (Arm 3) ──► 512-dim
                                                                                  │
                                                                 ┌────────────────┘
                                                                 ▼
                                                      Concatenate (1536-dim)
                                                                 │
                                            BN → Linear(1536→256) → ReLU → Dropout(0.5)
                                                                 │
                                                          Linear(256 → 2)
                                                                 │
                                                     PE positive / PE negative
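The signal-to-image step in the diagram can be illustrated for a single lead. This is a minimal sketch, not the code in `rest_net.py`: the function name `lead_to_scalogram`, the min-max normalization, and the grayscale-to-RGB replication are assumptions. Only the Morlet wavelet, the scales 1–127, and the 224 × 224 RGB output come from this README. How the leads within one arm are combined into a single image is not covered here.

```python
import numpy as np
import pywt
from PIL import Image


def lead_to_scalogram(signal, size=224):
    """Convert one 1-D ECG lead into a size x size RGB scalogram image."""
    scales = np.arange(1, 128)  # scales 1..127, as listed in this README
    coeffs, _ = pywt.cwt(np.asarray(signal, dtype=np.float64), scales, "morl")
    magnitude = np.abs(coeffs)  # shape: (127, len(signal))

    # Assumption: min-max scale the magnitudes to 0..255 before imaging.
    lo, hi = magnitude.min(), magnitude.max()
    pixels = ((magnitude - lo) / (hi - lo + 1e-12) * 255).astype(np.uint8)

    # Assumption: the grayscale map is replicated across three channels.
    return Image.fromarray(pixels).resize((size, size)).convert("RGB")


# Synthetic 5 Hz sine wave standing in for a real lead.
t = np.linspace(0, 2, 1000)
image = lead_to_scalogram(np.sin(2 * np.pi * 5 * t))
```

The resulting `image` has the size and mode that an ImageNet-pretrained ResNet-18 expects after the usual tensor conversion.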
| Component | Detail |
|---|---|
| Backbone | ResNet-18, ImageNet pretrained |
| Signal → image | Morlet CWT, scales 1–127 → 224 × 224 RGB |
| Fine-tuning | Only layer4 unfrozen per arm |
| Fusion | Concatenation + BN + 2-layer MLP |
| Learning rates | Fusion head: lr; backbone layer4: lr / 10 |
| Regularization | Weight decay 1e-4, Dropout 0.5 |
| LR schedule | ReduceLROnPlateau (patience = 3, factor = 0.1) |
| Checkpointing | Best model saved on validation loss improvement |
.
├── rest_net.py # Model definition, dataset, training loop
├── data_prep.py # One-time data splitting utility
└── README.md
Python 3.10+ recommended.
torch>=2.0
torchvision>=0.15
PyWavelets
Pillow
numpy
tqdm
matplotlib
pip install torch torchvision PyWavelets Pillow numpy tqdm matplotlib

data_prep.py and rest_net.py expect pickle files containing a Python list of tuples:

(filename: str, ecg_tensor: torch.Tensor, label: int)

| Field | Description |
|---|---|
| `filename` | Record identifier string (not used by the model) |
| `ecg_tensor` | Shape (1, num_leads, signal_length); the leading dimension is squeezed automatically |
| `label` | 1 = positive class (PE), 0 = negative class |
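Before training, it can help to confirm that a pickle matches this layout. The helpers below are a hypothetical sketch (`load_records` and `check_records` are not part of the repository), and the stand-in object with a `.shape` attribute only mimics a tensor. Because `pickle.load` can execute arbitrary code, only open files from a trusted source.

```python
import pickle
from types import SimpleNamespace


def load_records(path):
    """Load a pickled list of (filename, ecg_tensor, label) tuples."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


def check_records(records, min_leads=13, labels=(0, 1)):
    """Raise ValueError on the first record that breaks the layout."""
    for i, record in enumerate(records):
        if len(record) != 3:
            raise ValueError(f"record {i}: expected 3 fields, got {len(record)}")
        _filename, ecg, label = record
        shape = tuple(ecg.shape)
        if len(shape) != 3 or shape[0] != 1:
            raise ValueError(f"record {i}: expected shape (1, leads, length), got {shape}")
        if shape[1] < min_leads:
            raise ValueError(f"record {i}: needs at least {min_leads} leads, got {shape[1]}")
        if label not in labels:
            raise ValueError(f"record {i}: unexpected label {label!r}")
    return len(records)


# Stand-in record; the signal length of 5000 is an arbitrary example value.
demo = [("rec_001", SimpleNamespace(shape=(1, 13, 5000)), 1)]
n_ok = check_records(demo)
```

For a raw file that still uses labels 1 / -1, passing `labels=(1, -1)` applies the same structural checks.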
Minimum channel requirement: indices 0–12 (see LEAD_INDEX_MAPPING in rest_net.py).
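| Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Lead | I | II | III | aVR | aVL | aVF | V1 | V2 | V3 | V4 | V5 | V6 | Long Lead II |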
Update LEAD_INDEX_MAPPING in rest_net.py if your channel order differs.
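To show how the channel order feeds the three arms, here is a hedged sketch. The dictionary below is reconstructed from the table above and may not match the actual structure of `LEAD_INDEX_MAPPING` in `rest_net.py`; `ARM_LEADS` and `split_into_arms` are hypothetical names.

```python
# Hypothetical reconstruction of the channel layout from the table above.
LEAD_INDEX_MAPPING = {
    "I": 0, "II": 1, "III": 2, "aVR": 3, "aVL": 4, "aVF": 5,
    "V1": 6, "V2": 7, "V3": 8, "V4": 9, "V5": 10, "V6": 11,
    "Long Lead II": 12,
}

# Lead groups taken from the three arms described in this README.
ARM_LEADS = {
    "arm1_precordial": ["V1", "V2", "V3", "V4", "V5", "V6"],
    "arm2_limb": ["I", "II", "III", "aVR", "aVL", "aVF"],
    "arm3_long_ii": ["Long Lead II"],
}


def split_into_arms(ecg, mapping=LEAD_INDEX_MAPPING):
    """Group the channels of one recording by arm.

    `ecg` is anything indexable by channel, for example a tensor of
    shape (num_leads, signal_length) after the leading dim is squeezed.
    """
    return {
        arm: [ecg[mapping[lead]] for lead in leads]
        for arm, leads in ARM_LEADS.items()
    }


# Stand-in recording: 13 channels, each filled with its own channel index.
recording = [[channel] * 4 for channel in range(13)]
arms = split_into_arms(recording)
```

If a dataset stores leads in a different order, only the index values in the mapping need to change; the arm definitions stay the same.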
Starting from a raw pickle file with labels 1 / -1:
python data_prep.py --input data_before.pkl

Options:

| Flag | Default | Description |
|---|---|---|
| `--input` | (required) | Raw input pickle file |
| `--out_dir` | `.` | Output directory |
| `--seed` | `42` | Random seed |
Outputs: data_train.pkl, data_val.pkl, data_test.pkl (80 / 10 / 10 split).
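The splitting step might look like the following sketch. It is not the code in `data_prep.py`: the function name `split_records` is hypothetical, the remapping of raw label -1 to 0 is inferred from the data format, and a plain shuffled (non-stratified) split is an assumption.

```python
import random
from collections import Counter


def split_records(records, seed=42, fractions=(0.8, 0.1)):
    """Remap labels to 1 / 0, shuffle, and cut into train / val / test."""
    # Assumption: raw label -1 becomes 0, raw label 1 is kept.
    remapped = [(name, ecg, 1 if label == 1 else 0) for name, ecg, label in records]

    rng = random.Random(seed)  # local generator keeps the split reproducible
    rng.shuffle(remapped)

    n_train = int(len(remapped) * fractions[0])
    n_val = int(len(remapped) * fractions[1])
    train = remapped[:n_train]
    val = remapped[n_train:n_train + n_val]
    test = remapped[n_train + n_val:]  # remainder goes to the test set
    return train, val, test


# Synthetic records with no signal payload, only names and raw labels.
raw = [(f"rec_{i:03d}", None, 1 if i % 4 == 0 else -1) for i in range(100)]
train, val, test = split_records(raw)
label_counts = Counter(label for _, _, label in train + val + test)
```

Each returned list can then be written with `pickle.dump` to the corresponding output file. With a strongly imbalanced dataset, a stratified split may be preferable to this simple shuffle.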
python rest_net.py

Common flags:

| Flag | Default | Description |
|---|---|---|
| `--train_pkl` | `data_train.pkl` | Training set |
| `--val_pkl` | `data_val.pkl` | Validation set |
| `--test_pkl` | `data_test.pkl` | Test set |
| `--output_model` | `multi_arm_resnet.pth` | Checkpoint path |
| `--epochs` | `20` | Training epochs |
| `--batch_size` | `16` | Batch size (reduce if GPU OOM) |
| `--lr` | `0.001` | Base learning rate |
| `--num_workers` | `2` | DataLoader workers |
| `--no_pretrain` | off | Disable ImageNet weights |
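For readers who want to mirror this interface in their own script, the flags map naturally onto `argparse`. The sketch below reproduces the names and defaults from the table; the function name `build_parser` and the help strings are illustrative, and the parser in `rest_net.py` may be organized differently.

```python
import argparse


def build_parser():
    """Build a parser with the flags and defaults listed in the table."""
    parser = argparse.ArgumentParser(description="Train the multi-arm ECG classifier")
    parser.add_argument("--train_pkl", default="data_train.pkl", help="Training set")
    parser.add_argument("--val_pkl", default="data_val.pkl", help="Validation set")
    parser.add_argument("--test_pkl", default="data_test.pkl", help="Test set")
    parser.add_argument("--output_model", default="multi_arm_resnet.pth", help="Checkpoint path")
    parser.add_argument("--epochs", type=int, default=20, help="Training epochs")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="Base learning rate")
    parser.add_argument("--num_workers", type=int, default=2, help="DataLoader workers")
    # store_true makes the flag default to False, i.e. pretraining stays on.
    parser.add_argument("--no_pretrain", action="store_true", help="Disable ImageNet weights")
    return parser


# Parse an explicit argument list instead of sys.argv for demonstration.
args = build_parser().parse_args(["--epochs", "30", "--batch_size", "8"])
```

Flags that are not passed keep their defaults, so `args.lr` stays at the base learning rate and `args.no_pretrain` stays off.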
Example with custom paths:
python rest_net.py \
--train_pkl splits/data_train.pkl \
--val_pkl splits/data_val.pkl \
--test_pkl splits/data_test.pkl \
--epochs 30 --batch_size 8 \
--output_model checkpoints/best.pth

On Google Colab, mount Google Drive first and pass absolute paths:

from google.colab import drive
drive.mount('/content/drive')
!python rest_net.py \
--train_pkl "/content/drive/MyDrive/ECG/data_train.pkl" \
--val_pkl "/content/drive/MyDrive/ECG/data_val.pkl" \
--test_pkl "/content/drive/MyDrive/ECG/data_test.pkl"

Training writes the following files:

| File | Description |
|---|---|
| `multi_arm_resnet.pth` | Best model weights (lowest validation loss) |
| `training_curves.png` | Loss and accuracy curves (train vs. validation) |
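The "best weights" checkpoint and the plateau-based learning rate schedule interact in the training loop roughly as sketched below. The tiny `nn.Linear` model, the Adam optimizer, the temporary output path, and the validation losses are all stand-ins invented for illustration, not results from this project.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=3, factor=0.1
)

checkpoint_path = os.path.join(tempfile.mkdtemp(), "multi_arm_resnet.pth")
val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]  # made-up values
best_val_loss = float("inf")
saves = 0

for epoch, val_loss in enumerate(val_losses):
    # The training and validation passes would run here.
    if val_loss < best_val_loss:
        # Save only when validation loss improves on the best so far.
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)
        saves += 1
    # The scheduler watches the same metric and cuts the LR on a plateau.
    scheduler.step(val_loss)

final_lr = optimizer.param_groups[0]["lr"]
```

With these made-up losses the checkpoint is written on the first two epochs only, and the learning rate is reduced by the factor of 0.1 once more than `patience` consecutive epochs have passed without improvement.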
| Bug | Original | Fixed |
|---|---|---|
| `pd.Series` called without importing pandas | `NameError` at label count | Replaced with `collections.Counter` in `data_prep.py` |
| `criterion` undefined in training loop | `NameError` mid-training | `nn.CrossEntropyLoss()` defined before loop |
| Model never saved | `torch.save` commented out | Saved on every validation loss improvement |
| Colab `google.colab` import at top level | Crashes outside Colab | Removed; Colab usage documented in README |
| `tqdm.notebook` import crash in terminal | `ImportError` | Falls back to standard `tqdm` |
| Flat script with no entry point | Not importable as module | Wrapped in `main()` with `if __name__ == "__main__"` |
If you build on this work, please cite:
YHHuan, "RestNet_for_EKG_detection", GitHub, 2024.
https://github.com/YHHuan/RestNet_for_EKG_detection
MIT License