pypypypy5/vae

VAE for MNIST 64x64 RGB

Variational Autoencoder (VAE) for reconstructing MNIST images resized to 64x64 RGB.

Project Structure

vae/
├── src/
│   ├── data/
│   │   ├── __init__.py
│   │   └── dataset.py          # MNIST data loading and preprocessing
│   ├── models/
│   │   ├── __init__.py
│   │   ├── encoder.py          # Encoder network
│   │   ├── decoder.py          # Decoder network
│   │   └── model.py            # VAE model
│   ├── main.py                 # Main training script
│   └── train.py                # Training and evaluation functions
├── tests/
│   ├── conftest.py
│   ├── test_dataset.py         # Dataset tests
│   ├── test_model.py           # Model tests
│   └── test_training.py        # Training tests
├── outputs/
│   ├── best_models/            # Saved model checkpoints
│   └── figures/                # Training curves
├── requirements.txt
├── setup.py
└── README.md

Installation

1. Create Virtual Environment

python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

2. Install Dependencies

# Using setup.py (auto-detects GPU and installs correct PyTorch)
python setup.py

# Or manually
pip install torch torchvision torchaudio
pip install tqdm matplotlib pytest

Model Architecture

Encoder

  • Input: 64x64x3 RGB images
  • Architecture:
    • Conv2d (3 → 32, stride 2) + BatchNorm + ReLU
    • Conv2d (32 → 64, stride 2) + BatchNorm + ReLU
    • Conv2d (64 → 128, stride 2) + BatchNorm + ReLU
    • Conv2d (128 → 256, stride 2) + BatchNorm + ReLU
    • Flatten → Linear → μ (mean) and log σ² (log variance)
  • Output: μ and log σ², each in R^64; the latent vector z ∈ R^64 is sampled from them via the reparameterization trick
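
A minimal sketch of an encoder with this layout, for orientation only. The kernel size (4) and padding (1) are assumptions chosen so that four stride-2 convolutions take 64x64 down to 4x4; the actual values in `src/models/encoder.py` may differ. The `reparameterize` helper is likewise an illustrative name.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        channels = [3, 32, 64, 128, 256]
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [
                # kernel_size=4, padding=1 are assumed; each block halves H and W
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)          # 64x64 -> 4x4 spatially
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

    def forward(self, x):
        h = self.conv(x).flatten(start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)


def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)


mu, logvar = Encoder()(torch.randn(2, 3, 64, 64))
z = reparameterize(mu, logvar)
print(mu.shape, logvar.shape, z.shape)
```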

Decoder

  • Input: latent vector z ∈ R^64
  • Architecture:
    • Linear → Reshape to 256x4x4
    • ConvTranspose2d (256 → 128) + BatchNorm + ReLU
    • ConvTranspose2d (128 → 64) + BatchNorm + ReLU
    • ConvTranspose2d (64 → 32) + BatchNorm + ReLU
    • ConvTranspose2d (32 → 3) + Sigmoid
  • Output: reconstructed 64x64x3 RGB image
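
The decoder can be sketched the same way. As with the encoder, kernel size 4, stride 2, and padding 1 are assumptions: with those values each transposed convolution doubles the spatial size, so 4x4 grows to 64x64 in four steps. The real `src/models/decoder.py` may use different settings.

```python
import torch
import torch.nn as nn


class Decoder(nn.Module):
    def __init__(self, latent_dim: int = 64):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)

        def up(c_in, c_out):
            # Assumed hyperparameters: output size = 2 * input size
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)

        self.deconv = nn.Sequential(
            up(256, 128), nn.BatchNorm2d(128), nn.ReLU(inplace=True),  # 4 -> 8
            up(128, 64), nn.BatchNorm2d(64), nn.ReLU(inplace=True),    # 8 -> 16
            up(64, 32), nn.BatchNorm2d(32), nn.ReLU(inplace=True),     # 16 -> 32
            up(32, 3), nn.Sigmoid(),                                   # 32 -> 64, values in (0, 1)
        )

    def forward(self, z):
        h = self.fc(z).view(-1, 256, 4, 4)
        return self.deconv(h)


recon = Decoder()(torch.randn(2, 64))
print(recon.shape)
```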

Loss Function

  • Total Loss = Reconstruction Loss + KL Divergence
  • Reconstruction Loss: Binary Cross-Entropy between input and reconstruction
  • KL Divergence: -0.5 * Σ(1 + log(σ²) - μ² - σ²), measured between the approximate posterior N(μ, σ²) and the standard normal prior N(0, I)
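
The two terms above can be written in a few lines of PyTorch. This is a sketch, not the code in `src/train.py`: the function name `vae_loss` and the reduction (sum over pixels and latent dimensions, averaged over the batch) are assumptions.

```python
import torch
import torch.nn.functional as F


def vae_loss(recon_x, x, mu, logvar):
    """Return (total, reconstruction, KL) averaged over the batch (assumed reduction)."""
    batch_size = x.size(0)
    # Binary cross-entropy between reconstruction and input, both in [0, 1]
    recon = F.binary_cross_entropy(recon_x, x, reduction="sum") / batch_size
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), with logvar = log(sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return recon + kld, recon, kld


x = torch.rand(2, 3, 64, 64)
recon_x = torch.full_like(x, 0.5)
mu = torch.zeros(2, 64)       # posterior equal to the prior ...
logvar = torch.zeros(2, 64)   # ... so the KL term vanishes
total, recon, kld = vae_loss(recon_x, x, mu, logvar)
print(total.item(), recon.item(), kld.item())
```

With μ = 0 and log σ² = 0 the posterior matches the prior, so the KL term is zero and the total equals the reconstruction term.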

Usage

Training

cd src
python main.py

Training configuration (in src/main.py):

  • Batch size: 128
  • Learning rate: 0.001
  • Epochs: 30
  • Latent dimension: 64
  • Optimizer: Adam
  • Scheduler: ReduceLROnPlateau
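
For reference, the optimizer and scheduler from this list could be wired up as follows. The constant names, the `factor` and `patience` values, and the placeholder model are assumptions for illustration; only the batch size, learning rate, epoch count, latent dimension, Adam, and ReduceLROnPlateau come from the configuration above.

```python
import torch
import torch.nn as nn

BATCH_SIZE = 128
LEARNING_RATE = 1e-3
EPOCHS = 30
LATENT_DIM = 64

model = nn.Linear(LATENT_DIM, LATENT_DIM)  # placeholder standing in for the VAE
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# factor and patience are assumed values, not taken from src/main.py
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

# Simulated test losses that stop improving; the scheduler lowers the
# learning rate once the loss has plateaued for more than `patience` epochs.
for test_loss in [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]:
    scheduler.step(test_loss)

current_lr = optimizer.param_groups[0]["lr"]
print(current_lr)
```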

Testing

# Run all unit tests
pytest tests/ -v

# Run quick validation test
python test_run.py

Test Results

All tests pass:

tests/test_dataset.py::test_get_mnist_transforms_includes_expected_ops PASSED
tests/test_dataset.py::test_load_mnist_builds_dataloaders_without_touching_disk PASSED
tests/test_model.py::test_vae_forward_shape PASSED
tests/test_model.py::test_vae_output_range PASSED
tests/test_model.py::test_vae_reparameterization PASSED
tests/test_training.py::test_vae_loss_function PASSED
tests/test_training.py::test_train_one_epoch_updates_weights PASSED
tests/test_training.py::test_evaluate_runs_without_grad_updates PASSED

Model Details

  • Total Parameters: 2,171,392
  • Input Shape: (B, 3, 64, 64)
  • Output Shape: (B, 3, 64, 64)
  • Latent Space: (B, 64)

Training Output

During training, you'll see:

  • Batch-wise loss (Total, Reconstruction, KLD)
  • Epoch summary
  • Learning rate updates
  • Best model checkpoints
  • Training curves visualization

Example output:

Epoch [1/30]
  [Batch 1/938] Total Loss: 8064.12, Recon: 8056.45, KLD: 7.67
  ...
  [Train Summary] Avg Loss: 1637.13, Recon: 1540.61, KLD: 96.52
  [Test Summary] Avg Loss: 1384.01, Recon: 1291.43, KLD: 92.58

Outputs

  • Model Checkpoints: outputs/best_models/best_model.pth
  • Training Curves: outputs/figures/training_curves.png
    • Total Loss (Train vs Test)
    • Reconstruction Loss
    • KL Divergence Loss
    • Loss Component Ratio
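
A saved checkpoint can typically be restored as sketched below. This assumes the file holds a plain `state_dict`; the actual checkpoint layout written by the training script is not documented here, and the placeholder model and temporary path exist only so the example runs on its own.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

model = nn.Linear(64, 8)  # placeholder standing in for the trained VAE

with tempfile.TemporaryDirectory() as tmp:
    # In the project this would be outputs/best_models/best_model.pth
    ckpt_path = Path(tmp) / "best_model.pth"
    torch.save(model.state_dict(), ckpt_path)

    restored = nn.Linear(64, 8)
    state = torch.load(ckpt_path, map_location="cpu")  # load onto CPU regardless of origin
    restored.load_state_dict(state)

restored.eval()  # switch BatchNorm and similar layers to inference mode
```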

Requirements

  • Python 3.8-3.12
  • PyTorch >= 2.0
  • torchvision
  • tqdm
  • matplotlib
  • pytest

Device Support

Automatically detects and uses:

  • NVIDIA GPU (CUDA)
  • Apple Silicon (MPS)
  • Intel GPU (XPU)
  • CPU (fallback)
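
Device selection along these lines might be implemented as below. The helper name `get_device` and the priority order are assumptions; the `getattr` guards keep the check safe on PyTorch builds that lack the MPS or XPU backends.

```python
import torch


def get_device() -> torch.device:
    """Pick the first available backend: CUDA, then MPS, then XPU, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


device = get_device()
print(f"Using device: {device}")
```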

License

MIT License
