
YHHuan/RestNet_for_EKG_detection


RestNet for ECG Disease Detection

A multi-branch deep learning model for ECG-based binary disease classification, originally developed for Pulmonary Embolism (PE) detection.

Each branch independently processes a different group of ECG leads. The leads are converted to time-frequency scalogram images via the Continuous Wavelet Transform (CWT), and the three resulting feature streams are fused into a single classification head.
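The lead-to-image step can be sketched with PyWavelets. This is a minimal sketch, not the code in rest_net.py. The helper name `leads_to_scalogram` is hypothetical. How a multi-lead group becomes one 224 × 224 RGB input is also assumed here: CWT magnitudes are averaged across the leads of a group, min-max scaled, and the grey image is copied into three channels.

```python
import numpy as np
import pywt
from PIL import Image


def leads_to_scalogram(leads: np.ndarray, size: int = 224) -> np.ndarray:
    """Turn a (num_leads, signal_length) array into a (size, size, 3) uint8 image.

    Assumption: per-lead Morlet CWT magnitudes are averaged across the group;
    rest_net.py may combine leads differently.
    """
    scales = np.arange(1, 128)  # scales 1-127, as in the design table
    mags = []
    for lead in np.atleast_2d(leads):
        coefs, _ = pywt.cwt(lead, scales, "morl")  # coefs: (127, signal_length)
        mags.append(np.abs(coefs))
    scalogram = np.mean(mags, axis=0)

    # Min-max normalise to 0-255 (assumption: no log scaling or colormap).
    lo, hi = scalogram.min(), scalogram.max()
    scaled = (scalogram - lo) / (hi - lo + 1e-8) * 255.0
    img = Image.fromarray(scaled.astype(np.uint8)).resize((size, size))
    grey = np.asarray(img)
    return np.stack([grey, grey, grey], axis=-1)  # replicate grey into 3 "RGB" channels
```

The output shape matches what an ImageNet-pretrained ResNet-18 expects once it is converted to a (3, 224, 224) float tensor and normalised.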

Note: Training data is not included in this repository due to patient privacy constraints. See Data format for the expected input structure.


Architecture

ECG Signal  (12-lead + Long Lead II, ≥13 channels)
         │
         ├── V1–V6 ──────────────► CWT Scalogram ──► ResNet-18 (Arm 1) ──► 512-dim
         │                                                                       │
         ├── I / II / III / aVR / aVL / aVF ──► CWT Scalogram ──► ResNet-18 (Arm 2) ──► 512-dim
         │                                                                       │
         └── Long Lead II ──────────► CWT Scalogram ──► ResNet-18 (Arm 3) ──► 512-dim
                                                                                │
                                                               ┌────────────────┘
                                                               ▼
                                               Concatenate  (1536-dim)
                                                               │
                                           BN → Linear(1536→256) → ReLU → Dropout(0.5)
                                                               │
                                                   Linear(256 → 2)
                                                               │
                                                 PE positive / PE negative

Design choices

| Component | Detail |
|---|---|
| Backbone | ResNet-18, ImageNet-pretrained |
| Signal → image | Morlet CWT, scales 1–127 → 224 × 224 RGB |
| Fine-tuning | Only layer4 unfrozen in each arm |
| Fusion | Concatenation + BN + 2-layer MLP |
| Learning rates | Fusion head: lr; backbone layer4: lr / 10 |
| Regularization | Weight decay 1e-4, Dropout 0.5 |
| LR schedule | ReduceLROnPlateau (patience = 3, factor = 0.1) |
| Checkpointing | Best model saved whenever validation loss improves |
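The learning-rate, weight-decay, scheduling and checkpointing rows can be combined as below. The README does not name the optimizer, so Adam is an assumption here. The helper names `build_optimizer` and `BestCheckpoint` are hypothetical, as is the convention that fusion-head parameters are named `head.*`.

```python
import torch
import torch.nn as nn


def build_optimizer(model: nn.Module, lr: float = 1e-3):
    """Two parameter groups: fusion head at lr, unfrozen backbone (layer4) at lr / 10."""
    head, backbone = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen backbone layers are left out of the optimizer
        (head if name.startswith("head.") else backbone).append(param)
    # Adam is an assumption; the README only specifies lr, lr / 10 and weight decay.
    optimizer = torch.optim.Adam(
        [{"params": head, "lr": lr}, {"params": backbone, "lr": lr / 10}],
        weight_decay=1e-4,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=3
    )
    return optimizer, scheduler


class BestCheckpoint:
    """Save the model's state_dict whenever validation loss improves."""

    def __init__(self, path: str):
        self.path = path
        self.best = float("inf")

    def step(self, model: nn.Module, val_loss: float) -> bool:
        if val_loss < self.best:
            self.best = val_loss
            torch.save(model.state_dict(), self.path)
            return True
        return False
```

In a training loop, call `scheduler.step(val_loss)` and `checkpoint.step(model, val_loss)` once per epoch, after validation.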

Repository structure

.
├── rest_net.py      # Model definition, dataset, training loop
├── data_prep.py     # One-time data splitting utility
└── README.md

Requirements

Python 3.10+ recommended.

torch>=2.0
torchvision>=0.15
PyWavelets
Pillow
numpy
tqdm
matplotlib
pip install torch torchvision PyWavelets Pillow numpy tqdm matplotlib

Data format

data_prep.py and rest_net.py expect pickle files containing a Python list of tuples:

(filename: str,  ecg_tensor: torch.Tensor,  label: int)
| Field | Description |
|---|---|
| filename | Record identifier string (not used by the model) |
| ecg_tensor | Shape (1, num_leads, signal_length); the leading dimension is squeezed automatically |
| label | 1 = positive class (PE), 0 = negative class |

Minimum channel requirement: each tensor must provide at least 13 leads, i.e. indices 0–12 (see LEAD_INDEX_MAPPING in rest_net.py).
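The following hypothetical loader checks records against this format before training. `load_records` and its error messages are illustrative, not taken from rest_net.py. The loader squeezes the leading dimension and rejects tensors with fewer than 13 leads or labels other than 0 and 1.

```python
import pickle

import torch

MIN_CHANNELS = 13  # indices 0-12 must exist


def load_records(path: str):
    """Load a list of (filename, ecg_tensor, label) tuples and validate each entry."""
    with open(path, "rb") as f:
        records = pickle.load(f)
    cleaned = []
    for filename, ecg, label in records:
        ecg = torch.as_tensor(ecg)
        if ecg.dim() == 3 and ecg.shape[0] == 1:
            ecg = ecg.squeeze(0)  # (1, num_leads, L) -> (num_leads, L)
        if ecg.dim() != 2 or ecg.shape[0] < MIN_CHANNELS:
            raise ValueError(
                f"{filename}: expected >= {MIN_CHANNELS} leads, got shape {tuple(ecg.shape)}"
            )
        if label not in (0, 1):
            raise ValueError(f"{filename}: label must be 0 or 1, got {label!r}")
        cleaned.append((filename, ecg, int(label)))
    return cleaned
```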

Default lead index mapping

Index  0    1    2     3     4     5     6   7   8   9   10  11    12
Lead   I   II   III   aVR   aVL   aVF   V1  V2  V3  V4  V5  V6  Long Lead II

Update LEAD_INDEX_MAPPING in rest_net.py if your channel order differs.
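Given that mapping, splitting one record into the three arm inputs comes down to indexing the lead axis. The dictionary below mirrors the default table. Its layout and the names `ARM_LEADS` and `split_leads` are hypothetical; the actual `LEAD_INDEX_MAPPING` in rest_net.py may be structured differently.

```python
import torch

# Hypothetical mirror of the default lead table; rest_net.py may differ.
LEAD_INDEX_MAPPING = {
    "I": 0, "II": 1, "III": 2, "aVR": 3, "aVL": 4, "aVF": 5,
    "V1": 6, "V2": 7, "V3": 8, "V4": 9, "V5": 10, "V6": 11,
    "Long II": 12,
}

ARM_LEADS = {
    "precordial": ["V1", "V2", "V3", "V4", "V5", "V6"],  # arm 1
    "limb": ["I", "II", "III", "aVR", "aVL", "aVF"],     # arm 2
    "long_ii": ["Long II"],                              # arm 3
}


def split_leads(ecg: torch.Tensor) -> dict:
    """Slice a (num_leads, signal_length) tensor into the three arm inputs."""
    return {
        arm: ecg[[LEAD_INDEX_MAPPING[name] for name in leads]]
        for arm, leads in ARM_LEADS.items()
    }
```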


Usage

Step 1 — Prepare data splits (run once)

Starting from a raw pickle file in the tuple format above, but with labels 1 / -1:

python data_prep.py --input data_before.pkl

Options:

| Flag | Default | Description |
|---|---|---|
| --input | (required) | Raw input pickle file |
| --out_dir | . | Output directory |
| --seed | 42 | Random seed |

Outputs: data_train.pkl, data_val.pkl, data_test.pkl (80 / 10 / 10 split).
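Here is a sketch of what the one-time split could look like: remap labels -1 to 0, shuffle with a fixed seed, then cut the list 80 / 10 / 10. The function names are hypothetical, and data_prep.py may differ, for example by stratifying on label. Label counts use `collections.Counter`, in line with the pandas-free fix listed under the bug fixes.

```python
import os
import pickle
import random
from collections import Counter

LABEL_MAP = {1: 1, -1: 0}  # raw encoding -> training encoding (assumed)


def remap_and_split(records, seed=42, fractions=(0.8, 0.1, 0.1)):
    """Remap labels, shuffle deterministically, and cut into train/val/test lists."""
    # KeyError on any label other than 1 / -1 surfaces malformed input early.
    remapped = [(name, ecg, LABEL_MAP[label]) for name, ecg, label in records]
    rng = random.Random(seed)
    rng.shuffle(remapped)
    n_train = int(len(remapped) * fractions[0])
    n_val = int(len(remapped) * fractions[1])
    splits = (
        remapped[:n_train],
        remapped[n_train:n_train + n_val],
        remapped[n_train + n_val:],  # remainder goes to test
    )
    for name, split in zip(("train", "val", "test"), splits):
        print(name, len(split), dict(Counter(lbl for _, _, lbl in split)))
    return splits


def save_splits(splits, out_dir="."):
    """Write data_train.pkl, data_val.pkl and data_test.pkl to out_dir."""
    for name, split in zip(("train", "val", "test"), splits):
        with open(os.path.join(out_dir, f"data_{name}.pkl"), "wb") as f:
            pickle.dump(split, f)
```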


Step 2 — Train

python rest_net.py

Common flags:

| Flag | Default | Description |
|---|---|---|
| --train_pkl | data_train.pkl | Training set |
| --val_pkl | data_val.pkl | Validation set |
| --test_pkl | data_test.pkl | Test set |
| --output_model | multi_arm_resnet.pth | Checkpoint path |
| --epochs | 20 | Training epochs |
| --batch_size | 16 | Batch size (reduce if the GPU runs out of memory) |
| --lr | 0.001 | Base learning rate |
| --num_workers | 2 | DataLoader workers |
| --no_pretrain | off | Disable ImageNet weights |
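The flag table maps onto a standard argparse parser. The sketch below is hypothetical and shows only the defaults and types one would expect. Treating `--no_pretrain` as a `store_true` switch is an assumption based on its default of "off".

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    p = argparse.ArgumentParser(description="Train the multi-arm ResNet-18 ECG classifier")
    p.add_argument("--train_pkl", default="data_train.pkl")
    p.add_argument("--val_pkl", default="data_val.pkl")
    p.add_argument("--test_pkl", default="data_test.pkl")
    p.add_argument("--output_model", default="multi_arm_resnet.pth")
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--num_workers", type=int, default=2)
    # Assumed to be a boolean switch: present -> skip ImageNet weights.
    p.add_argument("--no_pretrain", action="store_true", help="disable ImageNet weights")
    return p
```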

Example with custom paths:

python rest_net.py \
  --train_pkl splits/data_train.pkl \
  --val_pkl   splits/data_val.pkl \
  --test_pkl  splits/data_test.pkl \
  --epochs 30 --batch_size 8 \
  --output_model checkpoints/best.pth

Running on Google Colab

from google.colab import drive
drive.mount('/content/drive')

!python rest_net.py \
  --train_pkl "/content/drive/MyDrive/ECG/data_train.pkl" \
  --val_pkl   "/content/drive/MyDrive/ECG/data_val.pkl" \
  --test_pkl  "/content/drive/MyDrive/ECG/data_test.pkl"

Outputs

| File | Description |
|---|---|
| multi_arm_resnet.pth | Best model weights (lowest validation loss) |
| training_curves.png | Loss and accuracy curves (train vs. validation) |
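A figure like training_curves.png can be produced from per-epoch metrics with matplotlib. The `history` dictionary layout and `plot_curves` are hypothetical. The figure has two panels, loss and accuracy, and each panel shows training and validation curves.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend: write the PNG without a display
import matplotlib.pyplot as plt


def plot_curves(history: dict, path: str = "training_curves.png") -> None:
    """history holds per-epoch lists: train_loss, val_loss, train_acc, val_acc."""
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(epochs, history["train_loss"], label="train")
    ax_loss.plot(epochs, history["val_loss"], label="validation")
    ax_loss.set(xlabel="epoch", ylabel="loss", title="Loss")
    ax_acc.plot(epochs, history["train_acc"], label="train")
    ax_acc.plot(epochs, history["val_acc"], label="validation")
    ax_acc.set(xlabel="epoch", ylabel="accuracy", title="Accuracy")
    for ax in (ax_loss, ax_acc):
        ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
```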

Bugs fixed from original notebook

| Bug | Original behavior | Fix |
|---|---|---|
| pd.Series called without importing pandas | NameError at label count | Replaced with collections.Counter in data_prep.py |
| criterion undefined in training loop | NameError mid-training | nn.CrossEntropyLoss() defined before the loop |
| Model never saved | torch.save commented out | Saved on every validation loss improvement |
| google.colab imported at top level | Crashes outside Colab | Removed; Colab usage documented in README |
| tqdm.notebook import in a terminal | ImportError | Falls back to standard tqdm |
| Flat script with no entry point | Not importable as a module | Wrapped in main() with if __name__ == "__main__" |
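Two of these fixes can be sketched together: the pandas-free label count and the terminal-safe progress bar. `tqdm.auto` is one way to get the same fallback, because it selects the notebook widget only inside Jupyter. Whether rest_net.py uses it or an explicit try/except is not stated. `label_counts` is a hypothetical helper.

```python
from collections import Counter

from tqdm.auto import tqdm  # notebook bar in Jupyter, plain text bar in a terminal


def label_counts(records):
    """Count labels in a list of (filename, ecg_tensor, label) tuples without pandas."""
    return Counter(label for _, _, label in records)


def main():
    records = [("rec_a", None, 1), ("rec_b", None, 0), ("rec_c", None, 0)]  # toy data
    for _ in tqdm(records, desc="records"):
        pass  # a real loop would process each record here
    print(label_counts(records))


if __name__ == "__main__":
    main()  # runs only when executed as a script, not on import
```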

Citation

If you build on this work, please cite:

YHHuan, "RestNet_for_EKG_detection", GitHub, 2024.
https://github.com/YHHuan/RestNet_for_EKG_detection

License

MIT License
