Skip to content

Commit d310b39

Browse files
MathiasMNilsenMathias Methlie NilsenMathias Methlie Nilsen
authored
Fixed class inheritance for popt ensembles (#100)
* update to TrustRegion * some design changes to TrustRegion * decoupled GenOpt from Ensemble * decoupled GenOpt from Ensemble * Cleaned up code duplication and renamed some stuff * comments --------- Co-authored-by: Mathias Methlie Nilsen <[email protected]> Co-authored-by: Mathias Methlie Nilsen <[email protected]>
1 parent 8345980 commit d310b39

File tree

5 files changed

+78
-224
lines changed

5 files changed

+78
-224
lines changed

popt/loop/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Main loop for running optimization."""
1+
"""Main loop for running optimization."""

popt/loop/ensemble_base.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
# Internal imports
99
from popt.misc_tools import optim_tools as ot
1010
from pipt.misc_tools import analysis_tools as at
11-
from ensemble.ensemble import Ensemble as PETEnsemble
11+
from ensemble.ensemble import Ensemble as SupEnsemble
1212
from simulator.simple_models import noSimulation
1313

14-
class EnsembleOptimizationBaseClass(PETEnsemble):
14+
__all__ = ['EnsembleOptimizationBaseClass']
15+
16+
class EnsembleOptimizationBaseClass(SupEnsemble):
1517
'''
1618
Base class for the popt ensemble
1719
'''
@@ -33,61 +35,64 @@ def __init__(self, options, simulator, objective):
3335
else:
3436
sim = simulator
3537

36-
# Initialize PETEnsemble
38+
# Initialize the PET Ensemble
3739
super().__init__(options, sim)
3840

3941
# Unpack some options
4042
self.save_prediction = options.get('save_prediction', None)
4143
self.num_models = options.get('num_models', 1)
4244
self.transform = options.get('transform', False)
4345
self.num_samples = self.ne
44-
45-
# Define some variables
46+
47+
# Set objective function (callable)
48+
self.obj_func = objective
49+
self.state_func_values = None
50+
self.ens_func_values = None
51+
52+
# Initialize prior
53+
self._initialize_state_info() # Initialize cov, bounds, and state
54+
self._scale_state() # Scale self.state to [0, 1] if transform is True
55+
56+
def _initialize_state_info(self):
57+
'''
58+
Initialize covariance and bounds based on prior information.
59+
'''
60+
self.cov = np.array([])
4661
self.lb = []
4762
self.ub = []
4863
self.bounds = []
49-
self.cov = np.array([])
50-
51-
# Get bounds and varaince, and initialize state
64+
5265
for key in self.prior_info.keys():
5366
variable = self.prior_info[key]
54-
67+
5568
# mean
5669
self.state[key] = np.asarray(variable['mean'])
5770

5871
# Covariance
5972
dim = self.state[key].size
60-
cov = variable['variance']*np.ones(dim)
61-
73+
var = variable['variance']*np.ones(dim)
74+
6275
if 'limits' in variable.keys():
6376
lb, ub = variable['limits']
64-
self.lb(lb)
65-
self.ub(ub)
66-
67-
# transform cov to [0, 1] if transform is True
77+
self.lb.append(lb)
78+
self.ub.append(ub)
79+
80+
# transform var to [0, 1] if transform is True
6881
if self.transform:
69-
cov = np.clip(cov/(ub - lb)**2, 0, 1, out=cov)
82+
var = var/(ub - lb)**2
83+
var = np.clip(var, 0, 1, out=var)
7084
self.bounds += dim*[(0, 1)]
7185
else:
7286
self.bounds += dim*[(lb, ub)]
7387
else:
7488
self.bounds += dim*[(None, None)]
7589

7690
# Add to covariance
77-
self.cov = np.append(self.cov, cov)
78-
91+
self.cov = np.append(self.cov, var)
92+
self.dim = self.cov.shape[0]
93+
7994
# Make cov full covariance matrix
8095
self.cov = np.diag(self.cov)
81-
82-
# Scale the state to [0, 1] if transform is True
83-
self._scale_state()
84-
85-
# Set objective function (callable)
86-
self.obj_func = objective
87-
88-
# Objective function values
89-
self.state_func_values = None
90-
self.ens_func_values = None
9196

9297
def get_state(self):
9398
"""
@@ -98,6 +103,15 @@ def get_state(self):
98103
"""
99104
return ot.aug_optim_state(self.state, list(self.state.keys()))
100105

106+
def get_cov(self):
107+
"""
108+
Returns
109+
-------
110+
cov : numpy.ndarray
111+
Covariance matrix, shape (number of controls, number of controls)
112+
"""
113+
return self.cov
114+
101115
def vec_to_state(self, x):
102116
"""
103117
Converts a control vector to the internal state representation.
@@ -114,7 +128,7 @@ def get_bounds(self):
114128

115129
return self.bounds
116130

117-
def function(self, x, *args):
131+
def function(self, x, *args, **kwargs):
118132
"""
119133
This is the main function called during optimization.
120134
@@ -130,29 +144,41 @@ def function(self, x, *args):
130144
"""
131145
self._aux_input()
132146

133-
if len(x.shape) == 1:
134-
self.ne = self.num_models
135-
else:
136-
self.ne = x.shape[1]
147+
# check for ensmble
148+
if len(x.shape) == 1: self.ne = self.num_models
149+
else: self.ne = x.shape[1]
137150

138-
# convert x to state
139-
self.state = self.vec_to_state(x) # go from nparray to dict
151+
# convert x (nparray) to state (dict)
152+
self.state = self.vec_to_state(x)
140153

141154
# run the simulation
142155
self._invert_scale_state() # ensure that state is in [lb,ub]
156+
self._set_multilevel_state(self.state, x) # set multilevel state if applicable
143157
run_success = self.calc_prediction(save_prediction=self.save_prediction) # calculate flow data
158+
self._set_multilevel_state(self.state, x) # For some reason this has to be done again after calc_prediction
144159
self._scale_state() # scale back to [0, 1]
160+
161+
# Evaluate the objective function
145162
if run_success:
146-
func_values = self.obj_func(self.pred_data, self.sim.input_dict, self.sim.true_order)
163+
func_values = self.obj_func(
164+
self.pred_data,
165+
input_dict=self.sim.input_dict,
166+
true_order=self.sim.true_order,
167+
**kwargs
168+
)
147169
else:
148170
func_values = np.inf # the simulations have crashed
149171

150-
if len(x.shape) == 1:
151-
self.state_func_values = func_values
152-
else:
153-
self.ens_func_values = func_values
172+
if len(x.shape) == 1: self.state_func_values = func_values
173+
else: self.ens_func_values = func_values
154174

155175
return func_values
176+
177+
def _set_multilevel_state(self, state, x):
178+
if 'multilevel' in self.keys_en.keys() and len(x.shape) > 1:
179+
en_size = ot.get_list_element(self.keys_en['multilevel'], 'en_size')
180+
self.state = ot.toggle_ml_state(self.state, en_size)
181+
156182

157183
def _aux_input(self):
158184
"""

0 commit comments

Comments
 (0)