Variational Autoencoder (VAE) for reconstructing MNIST images resized to 64x64 RGB.
vae/
├── src/
│ ├── data/
│ │ ├── __init__.py
│ │ └── dataset.py # MNIST data loading and preprocessing
│ ├── models/
│ │ ├── __init__.py
│ │ ├── encoder.py # Encoder network
│ │ ├── decoder.py # Decoder network
│ │ └── model.py # VAE model
│ ├── main.py # Main training script
│ └── train.py # Training and evaluation functions
├── tests/
│ ├── conftest.py
│ ├── test_dataset.py # Dataset tests
│ ├── test_model.py # Model tests
│ └── test_training.py # Training tests
├── outputs/
│ ├── best_models/ # Saved model checkpoints
│ └── figures/ # Training curves
├── requirements.txt
├── setup.py
└── README.md
python3 -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Using setup.py (auto-detects GPU and installs correct PyTorch)
python setup.py
# Or manually
pip install torch torchvision torchaudio
pip install tqdm matplotlib pytest
Encoder:
- Input: 64x64x3 RGB images
- Architecture:
  - Conv2d (3 → 32, stride 2) + BatchNorm + ReLU
  - Conv2d (32 → 64, stride 2) + BatchNorm + ReLU
  - Conv2d (64 → 128, stride 2) + BatchNorm + ReLU
  - Conv2d (128 → 256, stride 2) + BatchNorm + ReLU
  - Flatten → Linear → μ (mean) and log σ² (log variance)
- Output: latent vector z ∈ R^64
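A minimal sketch of an encoder with this layer layout. The list above does not give kernel size, padding, or bias settings, so the values used here (4x4 kernels, padding 1, no conv bias) are assumptions, and the names `Encoder` and `reparameterize` are illustrative rather than taken from src/models/encoder.py.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Four stride-2 conv blocks followed by two linear heads (mu, log-variance)."""

    def __init__(self, latent_dim: int = 64):
        super().__init__()
        channels = [3, 32, 64, 128, 256]
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            # Assumed hyperparameters: kernel 4, padding 1, no conv bias.
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.features = nn.Sequential(*layers)
        # Each stride-2 block halves the resolution: 64 -> 32 -> 16 -> 8 -> 4.
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

    def forward(self, x):
        h = self.features(x).flatten(start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)


def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), keeping gradients for mu and logvar."""
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)


encoder = Encoder()
mu, logvar = encoder(torch.randn(8, 3, 64, 64))
z = reparameterize(mu, logvar)
print(mu.shape, logvar.shape, z.shape)
```

With these assumptions the four stride-2 convolutions reduce a 64x64 input to a 256x4x4 feature map, so each linear head maps 4096 features to the 64-dimensional latent space.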
Decoder:
- Input: latent vector z ∈ R^64
- Architecture:
  - Linear → Reshape to 256x4x4
  - ConvTranspose2d (256 → 128) + BatchNorm + ReLU
  - ConvTranspose2d (128 → 64) + BatchNorm + ReLU
  - ConvTranspose2d (64 → 32) + BatchNorm + ReLU
  - ConvTranspose2d (32 → 3) + Sigmoid
- Output: reconstructed 64x64x3 RGB image
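A minimal sketch of a decoder with this layer layout. Kernel size, stride, padding, and bias settings are not stated above; the values here (4x4 kernels, stride 2, padding 1, no conv bias) are assumptions chosen so that each transposed convolution doubles the resolution. The class name `Decoder` is illustrative and may not match src/models/decoder.py.

```python
import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Linear projection to 256x4x4, then four upsampling transposed-conv layers."""

    def __init__(self, latent_dim: int = 64):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
        channels = [256, 128, 64, 32]
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            # Assumed hyperparameters: kernel 4, stride 2, padding 1, no bias.
            # Output size = (in - 1) * 2 - 2 * 1 + 4 = 2 * in.
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Final layer maps to 3 channels; Sigmoid keeps pixel values in [0, 1].
        layers += [
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        h = self.fc(z).view(-1, 256, 4, 4)
        return self.net(h)  # 4 -> 8 -> 16 -> 32 -> 64


decoder = Decoder()
recon = decoder(torch.randn(8, 64))
print(recon.shape, float(recon.min()), float(recon.max()))
```

The Sigmoid output range matters for the loss: binary cross-entropy expects predictions in [0, 1].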
- Total Loss = Reconstruction Loss + KL Divergence
- Reconstruction Loss: Binary Cross-Entropy between input and reconstruction
- KL Divergence: -0.5 * Σ(1 + log(σ²) - μ² - σ²)
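The loss terms above can be sketched as follows. The reduction is an assumption: both terms are summed over each sample and then averaged over the batch, which gives values on roughly the same scale as the example training output, but src/train.py may reduce differently. The function name `vae_loss` is illustrative.

```python
import math

import torch
import torch.nn.functional as F


def vae_loss(recon_x, x, mu, logvar):
    """Return (total, reconstruction, KL), each averaged over the batch."""
    batch_size = x.size(0)
    # Binary cross-entropy summed over all pixels and channels.
    recon = F.binary_cross_entropy(recon_x, x, reduction="sum") / batch_size
    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return recon + kld, recon, kld


# Sanity check: a constant 0.5 prediction costs ln(2) per pixel, and a
# posterior equal to the prior (mu = 0, log-variance = 0) has zero KL.
x = torch.rand(4, 3, 64, 64)
recon_x = torch.full_like(x, 0.5)
mu = torch.zeros(4, 64)
logvar = torch.zeros(4, 64)
total, recon, kld = vae_loss(recon_x, x, mu, logvar)
print(f"total={total.item():.2f} recon={recon.item():.2f} kld={kld.item():.2f}")
print(f"ln(2) * 3 * 64 * 64 = {math.log(2) * 3 * 64 * 64:.2f}")
```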
cd src
python main.py
Training configuration (in src/main.py):
- Batch size: 128
- Learning rate: 0.001
- Epochs: 30
- Latent dimension: 64
- Optimizer: Adam
- Scheduler: ReduceLROnPlateau
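As a rough illustration of how this configuration fits together, the sketch below wires Adam to ReduceLROnPlateau and feeds the scheduler a simulated test loss. The `factor` and `patience` values, the stand-in model, and the simulated losses are all assumptions for demonstration; the actual settings live in src/main.py.

```python
import torch
import torch.nn as nn

# Values from the configuration list above.
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 30
LATENT_DIM = 64

model = nn.Linear(LATENT_DIM, LATENT_DIM)  # hypothetical stand-in for the VAE
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# factor and patience are assumed values, not taken from the project.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

# Simulated per-epoch test losses that stop improving after the second epoch.
simulated_test_losses = [100.0, 90.0, 90.0, 90.0, 90.0, 90.0]
for epoch, test_loss in enumerate(simulated_test_losses, start=1):
    # The scheduler is stepped once per epoch with the monitored metric.
    scheduler.step(test_loss)
    lr = optimizer.param_groups[0]["lr"]
    print(f"epoch {epoch}: test loss {test_loss:.1f}, lr {lr:.6f}")
```

The learning rate is multiplied by `factor` once the monitored loss has failed to improve for more than `patience` consecutive epochs.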
# Run all unit tests
pytest tests/ -v
# Run quick validation test
python test_run.py
All tests pass successfully:
tests/test_dataset.py::test_get_mnist_transforms_includes_expected_ops PASSED
tests/test_dataset.py::test_load_mnist_builds_dataloaders_without_touching_disk PASSED
tests/test_model.py::test_vae_forward_shape PASSED
tests/test_model.py::test_vae_output_range PASSED
tests/test_model.py::test_vae_reparameterization PASSED
tests/test_training.py::test_vae_loss_function PASSED
tests/test_training.py::test_train_one_epoch_updates_weights PASSED
tests/test_training.py::test_evaluate_runs_without_grad_updates PASSED
- Total Parameters: 2,171,392
- Input Shape: (B, 3, 64, 64)
- Output Shape: (B, 3, 64, 64)
- Latent Space: (B, 64)
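The parameter total can be reproduced by hand under one set of layer hyperparameters: 4x4 kernels, no bias in the convolution and transposed-convolution layers, and separate linear heads for μ and log σ². These settings are not stated in this README; they are an assumption that happens to match the reported figure.

```python
def conv_params(c_in, c_out, kernel=4):
    """Weights of a conv or transposed-conv layer without bias (assumed)."""
    return c_in * c_out * kernel * kernel


def batchnorm_params(channels):
    """Learnable scale and shift per channel."""
    return 2 * channels


def linear_params(n_in, n_out):
    """Weights plus bias."""
    return n_in * n_out + n_out


enc_channels = [3, 32, 64, 128, 256]
dec_channels = [256, 128, 64, 32, 3]
flat = 256 * 4 * 4  # feature map size after four stride-2 convolutions
latent = 64

encoder = (
    sum(conv_params(a, b) for a, b in zip(enc_channels[:-1], enc_channels[1:]))
    + sum(batchnorm_params(c) for c in enc_channels[1:])
    + 2 * linear_params(flat, latent)  # mu head and log-variance head
)
decoder = (
    linear_params(latent, flat)
    + sum(conv_params(a, b) for a, b in zip(dec_channels[:-1], dec_channels[1:]))
    + sum(batchnorm_params(c) for c in dec_channels[1:-1])  # no BatchNorm on the output layer
)
total = encoder + decoder
print(f"encoder={encoder:,} decoder={decoder:,} total={total:,}")
```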
During training, you'll see:
- Batch-wise loss (Total, Reconstruction, KLD)
- Epoch summary
- Learning rate updates
- Best model checkpoints
- Training curves visualization
Example output:
Epoch [1/30]
[Batch 1/938] Total Loss: 8064.12, Recon: 8056.45, KLD: 7.67
...
[Train Summary] Avg Loss: 1637.13, Recon: 1540.61, KLD: 96.52
[Test Summary] Avg Loss: 1384.01, Recon: 1291.43, KLD: 92.58
- Model Checkpoints: outputs/best_models/best_model.pth
- Training Curves: outputs/figures/training_curves.png
  - Total Loss (Train vs Test)
  - Reconstruction Loss
  - KL Divergence Loss
  - Loss Component Ratio
- Python 3.8-3.12
- PyTorch >= 2.0
- torchvision
- tqdm
- matplotlib
- pytest
Automatically detects and uses:
- NVIDIA GPU (CUDA)
- Apple Silicon (MPS)
- Intel GPU (XPU)
- CPU (fallback)
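One way such a fallback chain could be written is sketched below. The function name `get_device` is hypothetical and the project's own detection logic may differ. The `getattr` guards are there because `torch.xpu` is only present in more recent PyTorch builds.

```python
import torch


def get_device() -> torch.device:
    """Pick the first available backend in the order listed above."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


device = get_device()
x = torch.zeros(2, 3, 64, 64, device=device)
print(f"Using device: {device}")
```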
MIT License