PyTorch Implementation of Bayesian Neural Networks trained using Bayes by Backprop (BBB). For more information, see our poster: Bayesian Neural Network Presentation
Paper: Blundell, C., Cornebise, J., Kavukcuoglu, K. and Wierstra, D., 2015, June. Weight uncertainty in neural networks. In International Conference on Machine Learning (pp. 1613-1622). PMLR. (http://proceedings.mlr.press/v37/blundell15.html)
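As a rough illustration of the idea behind Bayes by Backprop, the sketch below draws a single weight with the reparameterisation w = mu + softplus(rho) * eps and forms a one-sample estimate of the complexity term log q(w) - log p(w). It is written in plain Python rather than PyTorch so it runs without dependencies, and it is not taken from this repository: the function names are hypothetical, and a single zero-mean Gaussian prior is assumed here in place of the scale mixture prior used in the paper.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x)); keeps the standard deviation positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gaussian_log_pdf(x, mean, std):
    """Log density of N(mean, std^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def sample_weight(mu, rho, rng):
    """Reparameterised draw: w = mu + sigma * eps, with sigma = softplus(rho)."""
    sigma = softplus(rho)
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps, sigma


def complexity_cost(mu, rho, prior_std, rng):
    """Single-sample estimate of log q(w | mu, rho) - log p(w).

    Assumption: the prior is N(0, prior_std^2), not the paper's scale mixture.
    """
    w, sigma = sample_weight(mu, rho, rng)
    return gaussian_log_pdf(w, mu, sigma) - gaussian_log_pdf(w, 0.0, prior_std)


def closed_form_kl(mu, rho, prior_std):
    """Exact KL(q || p) for the Gaussian case, useful as a sanity check."""
    sigma = softplus(rho)
    return math.log(prior_std / sigma) + (sigma ** 2 + mu ** 2) / (2.0 * prior_std ** 2) - 0.5


if __name__ == "__main__":
    rng = random.Random(0)
    n = 20000
    estimate = sum(complexity_cost(0.5, 0.0, 1.0, rng) for _ in range(n)) / n
    print("Monte Carlo estimate:", estimate)
    print("Closed-form KL:      ", closed_form_kl(0.5, 0.0, 1.0))
```

Averaged over many draws, the sampled complexity cost should approach the closed-form KL divergence. In a real network the same estimate is summed over all weights and added to the negative log-likelihood of the data.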
Additional approximate inference methods are implemented including:
- Bayes by Backprop - Local Reparameterisation Trick (https://arxiv.org/pdf/1506.02557.pdf): Samples pre-activations instead of weights for lower variance, faster computation and convergence.
- Monte Carlo Dropout (https://arxiv.org/pdf/1506.02142.pdf): Keeps dropout active at test time (p=0.5) to generate uncertainty estimates.
- Functional Variational Inference (https://arxiv.org/pdf/1903.05779.pdf): Optimises an ELBO defined on stochastic processes, i.e. distributions over functions.
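To make the local reparameterisation trick from the list above more concrete, here is a minimal sketch for one linear layer with a factorised Gaussian posterior over its weights. Instead of sampling every weight, it samples each pre-activation directly from its implied Gaussian, whose mean is sum_i x_i * mu_ij and whose variance is sum_i x_i^2 * sigma_ij^2. The function names and the nested-list layout are assumptions made for illustration (plain Python, no bias term), not this repository's API.

```python
import math
import random


def lrt_preactivation_stats(x, mu, sigma):
    """Mean and variance of each pre-activation for one input vector.

    x     : list of n_in input values
    mu    : n_in x n_out nested list of posterior weight means
    sigma : n_in x n_out nested list of posterior weight standard deviations
    """
    n_in, n_out = len(x), len(mu[0])
    mean = [sum(x[i] * mu[i][j] for i in range(n_in)) for j in range(n_out)]
    var = [sum((x[i] * sigma[i][j]) ** 2 for i in range(n_in)) for j in range(n_out)]
    return mean, var


def lrt_sample(x, mu, sigma, rng):
    """Draw one noise value per output unit rather than one per weight."""
    mean, var = lrt_preactivation_stats(x, mu, sigma)
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mean, var)]


if __name__ == "__main__":
    rng = random.Random(0)
    x = [1.0, 2.0]
    mu = [[0.5], [-0.25]]
    sigma = [[0.1], [0.2]]
    print(lrt_preactivation_stats(x, mu, sigma))
    print(lrt_sample(x, mu, sigma, rng))
```

Because the noise is drawn per pre-activation, each example in a mini-batch gets independent noise at the cost of n_out random numbers instead of n_in * n_out, which is where the lower gradient variance and faster computation come from.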
Training and evaluation of the model are run through main.py, the main entry point. The BNNs and non-Bayesian MLPs are defined in networks.py. Functions required to run each experiment are included in /regression, /reinforcement_learning, and /classification.
Helper functions are included in utils: data_utils.py for loading data, logger_utils.py for logging progress, plot_utils.py for plotting, and load_model_utils.py for loading trained models.
At run-time, main.py reads from a model configuration set in config.py. The configurations required to replicate the results of the paper are provided as-is.
For example, to train a model, pass one of the three options to --model:
python3 main.py --model [regression|classification|rl]
The scripts weight_pruning.py and compute_ece.py perform post-hoc analysis using saved models.
- weight_pruning: 1) plots the distribution of weights, 2) computes the SNR of BNNs, 3) evaluates performance on pruned weights.
- compute_ece: 1) computes the expected calibration error (ECE) of a trained model, 2) plots a reliability diagram.
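The two post-hoc quantities above can be sketched in a few lines of plain Python. This is an illustrative sketch, not the code in weight_pruning.py or compute_ece.py. It assumes the signal-to-noise ratio is |mu| / sigma per weight (the definition used in the Bayes by Backprop paper), that pruning zeroes the posterior means with the lowest SNR, and that ECE uses equal-width confidence bins. All function names are hypothetical.

```python
def signal_to_noise(mu, sigma):
    """Per-weight SNR |mu| / sigma for a factorised Gaussian posterior."""
    return [abs(m) / s for m, s in zip(mu, sigma)]


def prune_by_snr(mu, sigma, fraction):
    """Zero the posterior means of the `fraction` of weights with the lowest SNR."""
    snr = signal_to_noise(mu, sigma)
    n_prune = int(fraction * len(mu))
    lowest = set(sorted(range(len(mu)), key=lambda i: snr[i])[:n_prune])
    return [0.0 if i in lowest else m for i, m in enumerate(mu)]


def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: bin-size-weighted mean of |accuracy - mean confidence|.

    confidences : predicted probability of the chosen class, in [0, 1]
    correct     : matching booleans (or 0/1) marking correct predictions
    """
    conf_sum = [0.0] * n_bins
    acc_sum = [0.0] * n_bins
    count = [0] * n_bins
    for c, ok in zip(confidences, correct):
        b = min(int(c * n_bins), n_bins - 1)  # confidence 1.0 falls in the last bin
        conf_sum[b] += c
        acc_sum[b] += 1.0 if ok else 0.0
        count[b] += 1
    n = len(confidences)
    # (count / n) * |acc_sum / count - conf_sum / count| simplifies to the term below.
    return sum(abs(acc_sum[b] - conf_sum[b]) / n for b in range(n_bins) if count[b])


if __name__ == "__main__":
    mu = [1.0, -0.2, 0.5, 0.05]
    sigma = [0.1, 0.4, 0.5, 0.5]
    print(prune_by_snr(mu, sigma, fraction=0.5))
    print(expected_calibration_error([0.95, 0.95, 0.65, 0.65], [1, 1, 1, 0]))
```

A reliability diagram plots the per-bin accuracy against the per-bin mean confidence, so the same bin sums can be reused for plotting.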
- Refactor reg_task.py, class_task into base and derived classes / sort out inheritance.