denskrlv/AlphaHMS


AlphaHMS

Photo by Shubham Dhage on Unsplash

⚑ Multi-modal Graph Neural Networks for Harmful Brain Activity Classification

AlphaHMS is a deep-learning pipeline for classifying harmful brain activity from EEG and spectrogram recordings, built around the HMS Harmful Brain Activity Classification Kaggle challenge. The system represents each recording as a pair of temporal graph sequences and learns a joint EEG + spectrogram representation with Graph Attention Networks, BiLSTM temporal encoders, hierarchical regional pooling, and cross-modal attention fusion.

🎯 The six target classes are: Seizure, LPD (Lateralized Periodic Discharges), GPD (Generalized Periodic Discharges), LRDA (Lateralized Rhythmic Delta Activity), GRDA (Generalized Rhythmic Delta Activity), and Other.


✨ Highlights

  • 🕸️ Graph-based EEG modelling: each 50 s EEG recording is split into 9 overlapping 10 s windows; nodes are the 19 EEG channels and edges are derived from inter-channel coherence (threshold 0.5).
  • 🌈 Graph-based spectrogram modelling: 600 s spectrograms are split into 119 windows over 4 spatial regions (LL, RL, LP, RP: left lateral, right lateral, left parasagittal, right parasagittal) with fixed spatial connectivity.
  • 🧩 Hierarchical pooling by clinical brain regions (Frontal, Central, Parietal, Occipital) before temporal modelling.
  • 🔀 Cross-modal fusion with multi-head attention between EEG and spectrogram regional embeddings.
  • ⚑ PyTorch Lightning training with mixed precision (BF16), WandB logging, cross-validation, class-weighted / KL-divergence losses, and early stopping.
  • 🔍 Explainability via GNNExplainer and attention-weight inspection.
  • 📊 Baselines included: EEG-only GNN and a raw-EEG MLP.
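The window counts above are consistent with a 5 s stride, which is our inference rather than something this README states: (50 - 10) / 5 + 1 = 9 and (600 - 10) / 5 + 1 = 119. The sketch below reproduces that arithmetic and shows one way to turn a channel-by-channel coherence matrix into a thresholded, weighted edge list. The function names, the 5 s stride, the strict `>` comparison and the toy coherence values are illustrative assumptions; the repository's own graph builders live in src/data/utils/.

```python
from typing import List, Sequence, Tuple


def window_starts(duration_s: float, window_s: float, stride_s: float) -> List[float]:
    """Start times of fixed-length windows that fit entirely inside the recording."""
    count = int((duration_s - window_s) // stride_s) + 1
    return [i * stride_s for i in range(count)]


def coherence_edges(coh: Sequence[Sequence[float]], threshold: float = 0.5
                    ) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Keep an undirected edge (i, j) whenever coherence exceeds the threshold.

    Both directions are emitted, as message-passing libraries such as
    PyTorch Geometric expect for undirected graphs; the coherence value
    is kept as the edge weight.
    """
    edges, weights = [], []
    n = len(coh)
    for i in range(n):
        for j in range(i + 1, n):
            if coh[i][j] > threshold:
                edges += [(i, j), (j, i)]
                weights += [coh[i][j], coh[i][j]]
    return edges, weights


# Assumed 5 s stride: 9 EEG windows and 119 spectrogram windows.
print(len(window_starts(50, 10, 5)))    # 9
print(len(window_starts(600, 10, 5)))   # 119

# Toy 3-channel coherence matrix (made-up numbers).
coh = [[1.0, 0.8, 0.2],
       [0.8, 1.0, 0.6],
       [0.2, 0.6, 1.0]]
edges, weights = coherence_edges(coh, threshold=0.5)
print(edges)    # [(0, 1), (1, 0), (1, 2), (2, 1)]
```

With 19 channels the same loop yields at most 19 × 18 = 342 directed edges per window; the threshold controls how sparse each window graph is.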

📁 Repository Layout

AlphaHMS/
├── configs/                       # OmegaConf YAML configs
│   ├── graphs.yaml                # Preprocessing parameters
│   ├── model.yaml                 # Multi-modal GNN architecture
│   ├── model_eeg.yaml             # EEG-only baseline architecture
│   ├── train.yaml                 # Main training config
│   ├── train_4fold.yaml           # Cross-validation training
│   ├── train_eeg.yaml             # EEG-only baseline training
│   ├── training_mlp.yaml          # MLP baseline training
│   ├── inference_mlp.yaml         # MLP inference
│   └── smoke_test.yaml            # Quick smoke test
├── notebooks/
│   └── eda.ipynb                  # Exploratory data analysis & preprocessing
├── src/
│   ├── data/                      # Datasets, DataModules, graph builders
│   │   ├── graph_dataset.py
│   │   ├── graph_datamodule.py
│   │   ├── baseline_dataset.py
│   │   ├── baseline_datamodule.py
│   │   ├── raw_eeg_dataset.py
│   │   ├── raw_datamodule.py
│   │   ├── make_graph_dataset.py  # Build graph dataset from raw data
│   │   └── utils/
│   │       ├── eeg_process.py
│   │       └── spectrogram_process.py
│   ├── models/
│   │   ├── hms_model.py           # Multi-modal model
│   │   ├── hms_eeg_model.py       # EEG-only baseline
│   │   ├── eeg_mlp.py             # MLP baseline
│   │   ├── regularization.py
│   │   ├── explainer_wrappers.py
│   │   └── graph_layers/          # GAT, temporal, fusion, pooling, classifier
│   ├── lightning_trainer/         # LightningModules (multi-modal, EEG, MLP)
│   ├── explainers/
│   │   └── gnn_explainer.py
│   ├── train.py                   # Training entrypoint (GNN models)
│   ├── train_mlp.py               # Training entrypoint (MLP baseline)
│   ├── explain.py                 # GNNExplainer driver
│   ├── explain_attention.py       # Attention-weight analysis
│   └── explain_model.py
├── tests/                         # Pytest suite
├── inspect_data.py                # Quick data inspection utility
├── environment.yaml               # Conda environment specification
└── pytest.ini

🛠️ Installation

1. Create the Conda environment

conda env create -f environment.yaml -y
conda activate graph

2. Install PyTorch with the correct CUDA wheel

The environment file deliberately omits PyTorch so you can match your local CUDA version. For CUDA 12.1:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install torcheeg

You will also need PyTorch Geometric built for the same PyTorch / CUDA combination; follow the official PyTorch Geometric installation guide.
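Once both steps are done, a short import check can confirm the stack is in place before preprocessing. This is a sketch of our own, not a script shipped with the repository; the package list simply mirrors the install commands above, with torch_geometric being PyTorch Geometric's import name.

```python
import importlib.util

# Import names for the packages the installation steps mention.
REQUIRED = ["torch", "torchvision", "torchaudio", "torcheeg", "torch_geometric"]


def missing_packages(names=REQUIRED):
    """Return the names that cannot be found on the import path (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        import torch
        print("torch", torch.__version__, "CUDA available:", torch.cuda.is_available())
```

If CUDA is reported as unavailable, the PyTorch wheel most likely does not match the local driver, and PyTorch Geometric should be reinstalled after fixing it.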


📦 Dataset

1. Download the raw HMS data

You must accept the competition rules on Kaggle and have a Kaggle API token (~/.kaggle/kaggle.json) configured first.

mkdir -p data/raw && cd data/raw
kaggle competitions download -c hms-harmful-brain-activity-classification
unzip hms-harmful-brain-activity-classification.zip
cd ../..

Expected layout under data/raw/:

data/raw/
├── train.csv
├── train_eegs/            # parquet files, one per EEG recording
└── train_spectrograms/    # parquet files, one per spectrogram
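Before running the notebook, the layout above can be checked programmatically. The sketch below is a hypothetical helper (not part of the repository) that only tests for the three entries listed and for at least one .parquet file in each folder.

```python
from pathlib import Path
from typing import List


def check_raw_layout(root: str = "data/raw") -> List[str]:
    """Return a list of problems with the expected raw-data layout (empty if OK)."""
    base = Path(root)
    problems = []
    if not (base / "train.csv").is_file():
        problems.append("missing train.csv")
    for sub in ("train_eegs", "train_spectrograms"):
        folder = base / sub
        if not folder.is_dir():
            problems.append(f"missing {sub}/")
        elif not any(folder.glob("*.parquet")):
            problems.append(f"no .parquet files in {sub}/")
    return problems


if __name__ == "__main__":
    issues = check_raw_layout()
    print("Layout OK" if not issues else "\n".join(issues))
```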

2. Run the EDA / preprocessing notebook

jupyter execute notebooks/eda.ipynb

3. Build the graph dataset

python src/data/make_graph_dataset.py

This produces one data/processed/patient_{id}.pt per patient plus a metadata.pt index. See src/data/README.md for full details on graph construction, output format, and memory requirements.
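The per-patient files can be enumerated without loading them, which is handy for spot checks or for splitting by patient. The sketch below relies only on the patient_{id}.pt naming described above and assumes the ids are integers; it does not read metadata.pt, whose contents are documented in src/data/README.md, and the function name is hypothetical.

```python
import re
from pathlib import Path
from typing import Dict

# Matches the patient_{id}.pt naming; assumes integer patient ids.
PATIENT_FILE = re.compile(r"^patient_(\d+)\.pt$")


def index_patient_files(processed_dir: str = "data/processed") -> Dict[int, Path]:
    """Map patient id -> path for every patient_{id}.pt file, sorted by id."""
    index = {}
    for path in Path(processed_dir).glob("patient_*.pt"):
        match = PATIENT_FILE.match(path.name)
        if match:  # skips names like patient_x.pt that do not carry an integer id
            index[int(match.group(1))] = path
    return dict(sorted(index.items()))


if __name__ == "__main__":
    files = index_patient_files()
    print(f"{len(files)} patient files found")
```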


🚀 Training

All training scripts log to Weights & Biases; run wandb login once before starting.

Multi-modal GNN (main model)

python src/train.py --train-config configs/train.yaml

Cross-validation (fold count set in configs/train_4fold.yaml)

python src/train.py --train-config configs/train_4fold.yaml
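Because the graph dataset is stored per patient, one common way to build folds is to assign whole patients to a fold, so that no patient contributes recordings to both training and validation. The sketch below does this with a seeded shuffle followed by round-robin assignment; it is an illustration only, and whether the repository's cross-validation config groups by patient is not stated in this README. The default of 4 folds simply mirrors the config file name.

```python
import random
from typing import Dict, List, Sequence, Tuple


def patient_folds(patient_ids: Sequence[int], n_folds: int = 4, seed: int = 0
                  ) -> List[Tuple[List[int], List[int]]]:
    """Split unique patient ids into n_folds (train_ids, val_ids) pairs.

    Every patient lands in exactly one validation fold, so recordings from
    the same patient never straddle the train/validation boundary.
    """
    unique = sorted(set(patient_ids))
    random.Random(seed).shuffle(unique)
    fold_of: Dict[int, int] = {pid: i % n_folds for i, pid in enumerate(unique)}
    folds = []
    for k in range(n_folds):
        val = sorted(pid for pid, f in fold_of.items() if f == k)
        train = sorted(pid for pid, f in fold_of.items() if f != k)
        folds.append((train, val))
    return folds
```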

EEG-only GNN baseline

python src/train.py --train-config configs/train_eeg.yaml

MLP baseline (raw EEG)

python src/train_mlp.py --config configs/training_mlp.yaml

Quick smoke test

python src/train.py --train-config configs/smoke_test.yaml

Evaluation runs automatically at the end of training. Checkpoints are written to the directory specified in the config.


🏗️ Model Architecture

The multi-modal model (see configs/model.yaml) is composed of:

  1. EEG encoder: 2-layer multi-head GAT (64-dim, 4 heads) with coherence edge weights → hierarchical regional pooling → 2-layer BiLSTM (128-dim, bidirectional).
  2. Spectrogram encoder: 2-layer GAT (64-dim, 4 heads) over the 4 spatial regions → BiLSTM (128-dim, bidirectional).
  3. Cross-modal fusion: multi-head cross-attention (8 heads, 256-dim) between EEG and spectrogram regional embeddings, with attention pooling over regions.
  4. Classifier: MLP with hidden sizes [256, 128], ELU activations, dropout 0.3 → 6-class softmax.
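The hierarchical regional pooling in step 1 can be pictured as averaging the embeddings of the channels that belong to each clinical region. The sketch below assumes the 19 standard 10-20 channel names and a hypothetical channel-to-region assignment in which the temporal electrodes are folded into the nearest region; the repository's actual grouping and pooling operator may differ, and plain lists stand in for tensors.

```python
from typing import Dict, List

# Hypothetical grouping of the 19 10-20 channels into the four regions.
REGIONS: Dict[str, List[str]] = {
    "Frontal":   ["Fp1", "Fp2", "F7", "F3", "Fz", "F4", "F8"],
    "Central":   ["T3", "C3", "Cz", "C4", "T4"],
    "Parietal":  ["T5", "P3", "Pz", "P4", "T6"],
    "Occipital": ["O1", "O2"],
}


def regional_mean_pool(node_emb: Dict[str, List[float]],
                       regions: Dict[str, List[str]] = REGIONS) -> Dict[str, List[float]]:
    """Average the per-channel embeddings of the channels in each region."""
    pooled = {}
    for region, channels in regions.items():
        vectors = [node_emb[ch] for ch in channels]
        dim = len(vectors[0])
        pooled[region] = [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]
    return pooled
```

Pooling this way turns each window's 19 channel vectors into 4 region vectors of the same width, which keeps the sequence fed to the temporal encoder small.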

Loss defaults to KL divergence against the soft expert-vote distribution; class weighting, graph-Laplacian regularization, and edge-weight penalties are all configurable.
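With soft labels, the target for each recording is the share of expert votes per class, and the loss compares that distribution p with the model's softmax output q via KL(p ‖ q) = Σ_c p_c log(p_c / q_c). The pure-Python sketch below spells this out for one sample; in PyTorch the usual equivalent is torch.nn.functional.kl_div(log_q, p, reduction="batchmean"), which expects log-probabilities as its first argument. The vote counts and logits are made up, and this is not the repository's loss code.

```python
import math
from typing import List, Sequence

CLASSES = ["Seizure", "LPD", "GPD", "LRDA", "GRDA", "Other"]


def votes_to_distribution(votes: Sequence[int]) -> List[float]:
    """Normalise per-class expert vote counts into a probability distribution."""
    total = sum(votes)
    return [v / total for v in votes]


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_divergence(p: Sequence[float], log_q: Sequence[float]) -> float:
    """KL(p || q), using the convention 0 * log(0 / q) = 0."""
    return sum(pc * (math.log(pc) - lq) for pc, lq in zip(p, log_q) if pc > 0)


# Made-up example: 3 of 5 experts voted Seizure, 2 voted LPD.
p = votes_to_distribution([3, 2, 0, 0, 0, 0])
loss = kl_divergence(p, log_softmax([2.0, 1.5, 0.0, 0.0, 0.0, 0.0]))
```

Classes with zero votes contribute nothing to the sum, so the loss only penalises how the model spreads probability relative to the classes the experts actually chose.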


🔬 Explainability

# GNNExplainer over a trained checkpoint
python src/explain.py

# Attention-weight visualisation
python src/explain_attention.py

🧪 Testing

pytest

The test suite covers the data module, preprocessing pipeline, multiprocessing, regularization, checkpoint resume, and spectrogram processing.


💻 Hardware Notes

  • Training was developed on H200 / RTX-class GPUs with BF16 mixed precision.
  • Preprocessing is CPU-bound and benefits from build_workers set to your physical core count.
  • Recommended: ≥ 16 GB RAM for preprocessing and at least one modern CUDA GPU for training.
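A quick way to find a sensible build_workers value on the preprocessing machine is sketched below. os.cpu_count() reports logical CPUs (including hyper-threads), so the sketch prefers psutil's physical-core count when psutil happens to be installed; psutil is an optional third-party package that this README does not list as a dependency.

```python
import os


def physical_core_count() -> int:
    """Best-effort physical core count, falling back to logical CPUs, then 1."""
    try:
        import psutil  # optional third-party dependency
        cores = psutil.cpu_count(logical=False)  # may return None on some platforms
    except ImportError:
        cores = None
    return cores or os.cpu_count() or 1


if __name__ == "__main__":
    print("Suggested build_workers:", physical_core_count())
```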

📎 Citation

If you use AlphaHMS in your research or build upon it, please cite this repository:

@software{krylov2025alphahms,
  author       = {Denis Krylov and Samuel Goldie and Alberto Pasinato and Serkan Akin and Leonardo Lago},
  title        = {{AlphaHMS}: Multi-modal Graph Neural Networks for Harmful Brain Activity Classification},
  year         = {2025},
  institution  = {Delft University of Technology},
  url          = {https://github.com/denskrlv/AlphaHMS},
  note         = {Built for the HMS Harmful Brain Activity Classification challenge (Kaggle)}
}

Languages

  • Python 90.4%
  • Jupyter Notebook 9.6%