
Bayesian NNs #3

Merged
nickwimer merged 10 commits into main from bayesian_nn on Jul 31, 2025

@nickwimer
Collaborator

No description provided.

@nickwimer requested a review from Copilot on July 31, 2025 15:31

Copilot AI left a comment

Pull Request Overview

This PR introduces Bayesian Neural Networks (BNNs) to the PT-MELT framework along with enhancements for hyperparameter tuning capabilities. The implementation adds probabilistic layers and uncertainty quantification features.

Key changes include:

  • Added new BayesianNeuralNetwork model with Bayesian layers using flipout technique
  • Enhanced hyperparameter tuning with new dependencies and model builder utility
  • Improved mixture density loss computation and visualization flexibility
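
For readers unfamiliar with the flipout technique named above, here is a minimal sketch of the idea in plain Python, so the arithmetic stays visible. It is not the code in ptmelt/layers.py: the function names, the Gaussian posterior, and the list-based matrices are assumptions made for illustration, and the KL term a Bayesian layer would also track is omitted. Flipout draws one weight perturbation per mini-batch and decorrelates it across examples with random sign vectors.

```python
import random


def matvec(x, w):
    """Multiply a row vector x (length n_in) by a matrix w (n_in x n_out)."""
    n_out = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(n_out)]


def flipout_dense(batch, w_mean, w_std, rng):
    """Forward pass of a dense layer with flipout weight perturbations.

    batch:  list of input rows, each of length n_in
    w_mean: posterior mean of the weights, n_in x n_out
    w_std:  posterior standard deviation of the weights, n_in x n_out
    rng:    a random.Random instance (hypothetical argument, for seeding)
    """
    n_in, n_out = len(w_mean), len(w_mean[0])
    # One zero-mean Gaussian perturbation, shared by the whole mini-batch.
    delta_w = [[w_std[i][j] * rng.gauss(0.0, 1.0) for j in range(n_out)]
               for i in range(n_in)]
    outputs = []
    for x in batch:
        # Independent random sign vectors for every example.
        s = [rng.choice((-1.0, 1.0)) for _ in range(n_in)]
        r = [rng.choice((-1.0, 1.0)) for _ in range(n_out)]
        mean_part = matvec(x, w_mean)
        # ((x * s) @ delta_w) * r is a pseudo-independent perturbation per example.
        pert_part = matvec([xi * si for xi, si in zip(x, s)], delta_w)
        outputs.append([m + p * rj for m, p, rj in zip(mean_part, pert_part, r)])
    return outputs
```

Passing a seeded `random.Random` makes the sampled perturbation and signs repeatable, which is in the same spirit as the seed support this PR adds to the output layers.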

Reviewed Changes

Copilot reviewed 8 out of 9 changed files in this pull request and generated 8 comments.

| File | Description |
| --- | --- |
| setup.py | Added hyperparameter tuning dependencies and version bump |
| ptmelt/utils/hp_tuning.py | New model builder utility for hyperparameter optimization |
| ptmelt/utils/visualization.py | Made R² and RMSE parameters optional in plotting function |
| ptmelt/nn_utils.py | Added new activation functions and loss function utilities |
| ptmelt/models.py | Added BayesianNeuralNetwork class and refactored training loop |
| ptmelt/losses.py | Improved mixture density loss with numerical stability |
| ptmelt/layers.py | Added Bayesian flipout layer implementation |
| ptmelt/blocks.py | Added BayesianBlock and enhanced output layers with seed support |
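
As context for the ptmelt/losses.py entry, the sketch below shows one common way to keep a mixture density loss numerically stable: clamp the log standard deviations and combine the components in log space with a log-sum-exp. This is an illustration only, not the code in this PR; the function name `mdn_nll`, its arguments, and the clamp bounds are assumptions, and it is written in plain Python for a single scalar target.

```python
import math


def mdn_nll(y, logits, means, log_sigmas, log_sigma_min=-7.0, log_sigma_max=7.0):
    """Negative log-likelihood of a scalar y under a 1-D Gaussian mixture.

    logits:     unnormalized mixture weights, one per component
    means:      component means
    log_sigmas: component log standard deviations (clamped before use)
    """
    # Log of the softmax normalizer, computed with the max-shift trick.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))

    log_terms = []
    for logit, mu, log_sigma in zip(logits, means, log_sigmas):
        # Clamping keeps sigma away from 0 and infinity (assumed bounds).
        log_sigma = min(max(log_sigma, log_sigma_min), log_sigma_max)
        z = (y - mu) / math.exp(log_sigma)
        log_pdf = -0.5 * z * z - log_sigma - 0.5 * math.log(2.0 * math.pi)
        log_terms.append(logit - log_norm + log_pdf)

    # Log-sum-exp over components avoids log(0) when every density underflows.
    m = max(log_terms)
    return -(m + math.log(sum(math.exp(t - m) for t in log_terms)))
```

A naive version that exponentiates each component density before summing would hit `log(0)` for a target far from every mean; the log-space form returns a large but finite loss instead.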
Comments suppressed due to low confidence (1)

ptmelt/layers.py:273

  • Commented out code should be removed rather than left in the codebase. If num_points is not needed, remove the parameter and the comment.
        self.register_buffer("dmax", torch.tensor(dmax))

Comment thread ptmelt/utils/hp_tuning.py Outdated
Comment thread ptmelt/utils/hp_tuning.py Outdated
Comment thread ptmelt/models.py Outdated
Comment thread ptmelt/models.py Outdated
Comment thread ptmelt/blocks.py
Comment thread ptmelt/blocks.py Outdated
Comment thread ptmelt/models.py
Comment thread ptmelt/models.py
@nickwimer self-assigned this on Jul 31, 2025
Collaborator Author

@nickwimer left a comment

Incorporated appropriate changes from review. Looks good now.

Comment thread ptmelt/utils/hp_tuning.py Outdated

return model, optimizer, criterion

elif arch_type == "resnet":
Collaborator Author

fixing this in next commit...
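
The diff context in this thread ends at the `resnet` branch, and the review comment itself is not shown, so what exactly was fixed is not recoverable here. Assuming the concern was an unimplemented branch (a later commit message mentions removing `pass` for unsupported architectures), one way to handle it is to raise explicitly. The function name, the `"ann"` branch, and the placeholder return values below are hypothetical and are not taken from ptmelt/utils/hp_tuning.py.

```python
def build_components(arch_type, config):
    """Return (model, optimizer, criterion) for a supported architecture.

    Hypothetical stand-in for a model builder; the strings below are
    placeholders for real model, optimizer, and loss objects.
    """
    if arch_type == "ann":  # "ann" is an assumed name, not taken from the PR
        model = ("ann-model", dict(config))
        optimizer = "optimizer-placeholder"
        criterion = "criterion-placeholder"
        return model, optimizer, criterion
    elif arch_type == "resnet":
        # Fail loudly instead of falling through with `pass` and returning None.
        raise NotImplementedError("arch_type 'resnet' is not implemented yet")
    raise ValueError(f"unknown arch_type: {arch_type!r}")
```

Raising here means a tuning run that samples an unsupported architecture stops with a clear message rather than failing later when the caller tries to unpack `None`.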

@nickwimer merged commit d3d04b7 into main on Jul 31, 2025
1 check passed
@nickwimer deleted the bayesian_nn branch on July 31, 2025 18:35
nickwimer added a commit that referenced this pull request on Oct 8, 2025
* Bayesian NNs (#3)

* adding support for hyperparameter tuning using Ray

* remove mandatory r2 and rmse from plot text

* WIP; initial code for ptBNN replicating tf flipout

* adding in updates to BNN working tests for iaps...

* adjusting the clamping for the MDN output to try to avoid NaNs

* updates to MDN output for stability

* adding in seed for reproduction testing

* cleaning up before PR

* fixing typos and adding in conditions for partial bayes blocks

* removing pass for unsupported architectures...todo to fully implement

* RNN models and LR Schedulers (#4)

* moving utility functions into class files...might deprecate soon

* making the mixture density loss have mse regularization

* adding in schedulers and early stopping

* fixing mse addition loss term for MDNs

* updating regression notebook

* adding in support for LSTM model from time series modeling work

* cleaning up old commented code

* removing commented code