MarkAppprogrammer/javaMNIST
javaMNIST

This project is a from-scratch (no ML libraries) neural network implementation in Java, trained and evaluated on the classic MNIST handwritten digits dataset.

How the code is structured

Main.java contains the neural network implementation:

  • ActivationFunction: sigmoid, sigmoid derivative, ReLU/Leaky ReLU, and softMax
  • LossFunction: cross-entropy loss (for softmax classification)
  • Neuron and Layer: store weights/biases for each layer
  • NeuralNetwork: forward pass, backpropagation, training, prediction, and testing
  • public class Main: CLI menu (1 = Train, 2 = Test)
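As a rough sketch of what the ActivationFunction pieces compute (method names here are illustrative and may differ from the ones in Main.java), the sigmoid, Leaky ReLU, and softmax functions can be written as:

```java
public class Activations {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    // Derivative expressed in terms of sigmoid itself: s * (1 - s)
    static double sigmoidDerivative(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    // Leaky ReLU; the negative slope of 0.01 is an assumed value
    static double leakyRelu(double x) { return x > 0 ? x : 0.01 * x; }
    static double leakyReluDerivative(double x) { return x > 0 ? 1.0 : 0.01; }

    // Numerically stable softmax: subtract the max logit before exponentiating
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }
}
```

Softmax outputs a probability distribution over the 10 digit classes, which pairs with the cross-entropy loss in LossFunction.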

PreProcess.java contains the dataset and math utilities:

  • Reads MNIST .idx files for images/labels (train-* and t10k-*)
  • Normalizes pixel values
  • Converts labels to one-hot vectors
  • Provides matrix helpers used by the network (dot, transpose, sumSecondAxis, etc.)
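The normalization and one-hot steps are simple enough to sketch directly (these are illustrative stand-ins, not the exact signatures in PreProcess.java):

```java
public class PreProcessSketch {
    // Global scaling: every pixel divided by 255.0, as described in the
    // accuracy-improvements section below
    static double[] normalize(int[] pixels) {
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) out[i] = pixels[i] / 255.0;
        return out;
    }

    // One-hot encoding: label 3 with 10 classes -> [0,0,0,1,0,0,0,0,0,0]
    static double[] oneHot(int label, int numClasses) {
        double[] out = new double[numClasses];
        out[label] = 1.0;
        return out;
    }
}
```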

The dataset files live in mnist-dataset/ in this repo:

  • train-images.idx3-ubyte, train-labels.idx1-ubyte
  • t10k-images.idx3-ubyte, t10k-labels.idx1-ubyte

Setup / Install

1. Install Java

This code is compiled to Java 8 bytecode for compatibility with older Java runtimes:

  • Ensure you have a JDK (not just a JRE) installed
  • Recommended: Java 8, or a newer JDK that compiles with --release 8

2. Get the MNIST dataset

Download the MNIST dataset (the same *.idx files referenced by the code), e.g. from Kaggle, then place the files into the mnist-dataset/ folder so the paths match what the code loads:

  • mnist-dataset/train-images.idx3-ubyte
  • mnist-dataset/train-labels.idx1-ubyte
  • mnist-dataset/t10k-images.idx3-ubyte
  • mnist-dataset/t10k-labels.idx1-ubyte
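For reference, the .idx image files follow a simple binary layout: four big-endian 32-bit integers (magic number 2051, image count, rows, columns) followed by the raw unsigned pixel bytes. A minimal parser sketch, assuming this standard layout (the actual reader lives in PreProcess.java and its method names may differ):

```java
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;

public class IdxImages {
    // Parses an idx3-ubyte image stream into [imageCount][rows * cols] pixels
    static int[][] read(InputStream raw) throws IOException {
        DataInputStream in = new DataInputStream(new BufferedInputStream(raw));
        int magic = in.readInt();  // DataInputStream.readInt() is big-endian, matching idx
        if (magic != 2051) throw new IOException("Unexpected magic number: " + magic);
        int n = in.readInt(), rows = in.readInt(), cols = in.readInt();
        int[][] images = new int[n][rows * cols];
        for (int i = 0; i < n; i++)
            for (int p = 0; p < rows * cols; p++)
                images[i][p] = in.readUnsignedByte();  // unsigned pixel value 0..255
        return images;
    }
}
```

Label files (idx1-ubyte) are the same idea with magic number 2049 and only a count header before the raw label bytes.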

Build and Run (macOS/Linux)

From the repo root (javaMNIST/):

rm -rf out
mkdir -p out

# Compile to Java 8 bytecode (prevents UnsupportedClassVersionError)
javac --release 8 -d out Main.java PreProcess.java

# Run
java -cp out Final.Main

The program will prompt you:

  • 1 to train (then it will run a test pass on the trained weights)
  • 2 to test (runs inference with the current in-memory weights; if you didn’t train in the same run, this will be near-random accuracy)

Accuracy improvements vs the original version

Key improvements made to get the network working well on MNIST:

  • Correct MNIST input normalization: switched to global scaling (pixel / 255.0) instead of per-image min/max scaling
  • Better hidden-layer nonlinearity: switched hidden activation from sigmoid to Leaky ReLU and updated backprop accordingly
  • Better training procedure: replaced full-batch gradient descent with mini-batch training plus momentum
  • Increased model capacity: changed network shape from 784 -> 10 -> 10 to 784 -> 64 -> 10
  • Added/finished test-time evaluation on t10k-* so we can report accuracy properly
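The mini-batch-plus-momentum update mentioned above follows the classic rule: keep a running velocity per weight, blend in each batch's gradient, and step along the velocity. A minimal sketch, assuming standard momentum (the hyperparameter names here are illustrative, not taken from Main.java):

```java
public class MomentumSGD {
    final double lr;          // learning rate
    final double mu;          // momentum coefficient, e.g. 0.9
    final double[] velocity;  // one velocity entry per weight

    MomentumSGD(int numWeights, double lr, double mu) {
        this.lr = lr;
        this.mu = mu;
        this.velocity = new double[numWeights];
    }

    // One update from a mini-batch gradient:
    //   v = mu * v - lr * grad;  w += v
    void step(double[] weights, double[] grad) {
        for (int i = 0; i < weights.length; i++) {
            velocity[i] = mu * velocity[i] - lr * grad[i];
            weights[i] += velocity[i];
        }
    }
}
```

Compared with full-batch gradient descent, mini-batches give many noisy-but-cheap updates per epoch, and momentum smooths that noise while accelerating progress along consistent gradient directions.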

Example results observed with the current default settings (train on 4000 samples, test on 2000 samples, epochs = 401):

  • Training accuracy reaches ~100% on the 4000 training subset
  • Test accuracy: about 90.65%

Before the accuracy improvements listed above, training on 4000 images typically plateaued in the low-to-mid 70% range (for example ~73.7% at epoch 400), and t10k evaluation was not implemented, making it hard to tell how well the model generalized.
