forked from aleximmer/Laplace
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlaplace.py
31 lines (25 loc) · 1.13 KB
/
laplace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from laplace.baselaplace import BaseLaplace
from laplace import *
def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure='kron',
            *args, **kwargs):
    """Simplified Laplace access using strings instead of different classes.

    Parameters
    ----------
    model : torch.nn.Module
    likelihood : {'classification', 'regression'}
    subset_of_weights : {'last_layer', 'all'}, default='last_layer'
        subset of weights to consider for inference
    hessian_structure : {'diag', 'kron', 'full'}, default='kron'
        structure of the Hessian approximation
    *args, **kwargs
        forwarded unchanged to the constructor of the selected class

    Returns
    -------
    laplace : BaseLaplace
        chosen subclass of BaseLaplace instantiated with additional arguments

    Raises
    ------
    KeyError
        if no subclass of BaseLaplace declares the ``_key``
        ``(subset_of_weights, hessian_structure)``; the message lists the
        combinations that are currently available.
    """
    # The registry is rebuilt on every call, so any BaseLaplace subclass that
    # exists at call time and has a ``_key`` attribute is selectable.
    laplace_map = {subclass._key: subclass for subclass in _all_subclasses(BaseLaplace)
                   if hasattr(subclass, '_key')}
    key = (subset_of_weights, hessian_structure)
    try:
        laplace_class = laplace_map[key]
    except KeyError:
        # Keep KeyError for backward compatibility, but make the failure
        # actionable instead of echoing only the bare tuple.
        available = ', '.join(sorted(repr(k) for k in laplace_map))
        raise KeyError(
            f'No Laplace implementation for subset_of_weights={subset_of_weights!r} '
            f'and hessian_structure={hessian_structure!r}; available '
            f'(subset_of_weights, hessian_structure) combinations: '
            f'{available or "none"}'
        ) from None
    return laplace_class(model, likelihood, *args, **kwargs)
def _all_subclasses(cls):
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in _all_subclasses(c)])