This repository contains an implementation of a Siamese Neural Network for comparing pairs of MNIST handwritten digits. The model is trained to determine whether two digit images represent the same digit or different digits.
Siamese networks are neural network architectures built from two identical subnetworks that share their weights. Both inputs pass through the same subnetwork, and the network learns a similarity measure by comparing the resulting feature representations.
- PyTorch implementation of a Siamese Neural Network
- Data preprocessing for creating pairs from the MNIST dataset
- Training and evaluation scripts
- Visualization tools for examining model predictions
- Alternative Keras implementation in Jupyter notebook format
```bash
# Clone the repository
git clone <repository-url>
cd siamese
# Install required packages
pip install torch torchvision matplotlib numpy
```
- `main.py` - Main training script
- `test.py` - Testing and visualization utilities
- `pre_process.py` - Data preprocessing and dataset creation
- `net.py` - Siamese network architecture definition
- `Siamese_keras.ipynb` - Alternative implementation using Keras
To train the model, run:

```bash
python main.py
```
This will:
- Download the MNIST dataset (if not already present)
- Create pairs of similar and dissimilar digit images
- Train the Siamese network for the specified number of epochs
- Save the trained model to `siamese_model.pth`
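The pair-creation step above can be illustrated with a short, plain-Python sketch. `make_pairs_sketch` is a hypothetical helper, not the repository's `make_pairs` from `pre_process.py`; it assumes one same-digit pair and one different-digit pair per image, and the label convention 1 = same digit, 0 = different digit.

```python
import random
from collections import defaultdict

def make_pairs_sketch(images, labels, seed=0):
    """Hypothetical pair builder: for every image, emit one positive pair
    (same label, target 1) and one negative pair (different label, target 0).
    Assumes at least two distinct labels are present."""
    rng = random.Random(seed)

    # Group sample indices by digit label so partners can be drawn quickly.
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    classes = sorted(by_label)

    pairs, targets = [], []
    for idx, lab in enumerate(labels):
        # Positive pair: any image with the same label (may be the image itself).
        pos = rng.choice(by_label[lab])
        pairs.append((images[idx], images[pos]))
        targets.append(1)

        # Negative pair: pick a different label first, then an image of that label.
        other = rng.choice([c for c in classes if c != lab])
        neg = rng.choice(by_label[other])
        pairs.append((images[idx], images[neg]))
        targets.append(0)
    return pairs, targets

# Toy usage with string placeholders instead of 28x28 images.
pairs, targets = make_pairs_sketch(["a0", "a1", "b0", "b1"], [0, 0, 1, 1])
print(len(pairs), targets)  # 8 [1, 0, 1, 0, 1, 0, 1, 0]
```

Choosing the negative label first keeps negatives balanced across digits; the real preprocessing may sample differently.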
To evaluate the trained model, run:

```bash
python test.py
```
This will:
- Load the saved model
- Evaluate its performance on the test dataset
- Visualize example predictions with the `plot_checker` function
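Because training uses `BCEWithLogitsLoss`, the model's raw outputs are logits, so evaluation needs a sigmoid and a decision threshold. The sketch below is a plain-Python illustration of that step, not the code in `test.py`; the 0.5 threshold and the convention that target 1 means "same digit" are assumptions.

```python
import math

def sigmoid(z):
    # Numerically stable logistic function: math.exp never gets a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def predict_same(logit, threshold=0.5):
    # Assumed convention: probability above the threshold means "same digit".
    return sigmoid(logit) > threshold

def pair_accuracy(logits, targets, threshold=0.5):
    # targets: 1 for same-digit pairs, 0 for different-digit pairs (assumed).
    correct = sum(int(predict_same(z, threshold)) == t for z, t in zip(logits, targets))
    return correct / len(targets)

print(pair_accuracy([2.0, -1.0, 0.3, -0.2], [1, 0, 0, 0]))  # 0.75
```

With a 0.5 threshold, `sigmoid(z) > 0.5` is equivalent to `z > 0`, so thresholding the raw logit at zero gives the same predictions.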
The Siamese network processes pairs of MNIST digit images through identical neural networks with shared weights. The model then computes the similarity between the resulting feature vectors to determine if the input images represent the same digit.
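A minimal PyTorch sketch of the shared-weight idea follows. `SiameseNetSketch` is hypothetical and is not the `SiameseNet` defined in `net.py`: the layer sizes, the 64-dimensional embedding, and the choice to compare embeddings with an element-wise absolute difference followed by a linear layer are all assumptions. The essential point is that a single `encoder` module processes both images, so the two branches share every weight.

```python
import torch
import torch.nn as nn

class SiameseNetSketch(nn.Module):
    """Hypothetical Siamese network for 1x28x28 MNIST images (not net.py)."""

    def __init__(self, embedding_dim=64):
        super().__init__()
        # One encoder instance: both inputs go through exactly these weights.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 28 -> 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 14 -> 7
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, embedding_dim),
        )
        # Maps the per-feature distance between the two embeddings to a single logit.
        self.head = nn.Linear(embedding_dim, 1)

    def forward(self, x1, x2):
        e1 = self.encoder(x1)
        e2 = self.encoder(x2)
        # |e1 - e2| is symmetric, so swapping the inputs yields the same logit.
        return self.head(torch.abs(e1 - e2)).squeeze(1)

# Usage: random tensors stand in for a batch of 4 image pairs.
model = SiameseNetSketch()
x1 = torch.randn(4, 1, 28, 28)
x2 = torch.randn(4, 1, 28, 28)
logits = model(x1, x2)
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])  # assumed: 1 = same digit
loss = nn.BCEWithLogitsLoss()(logits, targets)
loss.backward()
print(logits.shape)  # torch.Size([4])
```

The head returns raw logits rather than probabilities because `BCEWithLogitsLoss` applies the sigmoid internally.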
Key components:
- `SiameseNet`: The neural network architecture with shared weights
- `SiamenseDataset`: Custom dataset class for handling image pairs
- `make_pairs`: Function to create pairs of same/different digit images
- `BCEWithLogitsLoss`: Binary cross-entropy loss with a built-in sigmoid, applied to the model's raw output logits during training
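To show what `BCEWithLogitsLoss` computes, here is a plain-Python sketch of its per-pair formula in the numerically stable form ℓ(z, y) = max(z, 0) − z·y + log(1 + e^(−|z|)), where z is the logit and y the target. It is an illustration only; in training you would call the PyTorch module, which averages over the batch by default (`reduction="mean"`).

```python
import math

def bce_with_logits(z, y):
    # Equivalent to -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))],
    # rewritten so that exp() never receives a large positive argument.
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

def mean_bce_with_logits(logits, targets):
    # Batch average, matching the default reduction="mean".
    return sum(bce_with_logits(z, y) for z, y in zip(logits, targets)) / len(logits)

naive = -math.log(1 - 1 / (1 + math.exp(-2.0)))  # direct formula for y = 0, z = 2
print(abs(bce_with_logits(2.0, 0) - naive) < 1e-12)  # True
print(bce_with_logits(1000.0, 0))  # 1000.0, where the direct formula hits log(0)
```

Fusing the sigmoid into the loss this way is why the network can output raw logits instead of applying a sigmoid layer itself.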
After running the test script, visualizations will show pairs of digit images along with the model's prediction of whether they represent the same digit or different digits.
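A visualization of that kind can be sketched with matplotlib. `show_pair_sketch` and `verdict` are hypothetical stand-ins, not the repository's `plot_checker`; they assume each image is a NumPy array with 784 values (flat or 28x28) and that the model's output has already been converted to a same-digit probability.

```python
import matplotlib.pyplot as plt
import numpy as np

def verdict(prob_same, threshold=0.5):
    # Assumed decision rule: probability above the threshold means "same digit".
    return "same digit" if prob_same > threshold else "different digits"

def show_pair_sketch(img1, img2, prob_same, threshold=0.5):
    """Draw two MNIST digits side by side, titled with the model's verdict."""
    fig, axes = plt.subplots(1, 2, figsize=(4, 2))
    for ax, img in zip(axes, (img1, img2)):
        # Accept flat 784-vectors or 28x28 arrays.
        ax.imshow(np.asarray(img).reshape(28, 28), cmap="gray")
        ax.axis("off")
    fig.suptitle(f"{verdict(prob_same, threshold)} (p={prob_same:.2f})")
    return fig

# Usage: blank placeholder images; call plt.show() to open a window.
fig = show_pair_sketch(np.zeros(784), np.ones((28, 28)), 0.9)
```

Returning the figure instead of showing it immediately lets a caller save it with `fig.savefig(...)` or collect several pairs before displaying them.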
[Add your license information here]