Skip to content

Commit 1b0798c

Browse files
MathiasMNilsen, Mathias Methlie Nilsen, Mathias Methlie Nilsen, rolfjl
authored
fixed some issues (#106)
* update to TrustRegion
* some design changes to TrustRegion
* decoupled GenOpt from Ensemble
* decoupled GenOpt from Ensemble
* Cleaned up code duplication and renamed some stuff
* comments
* improved and cleaned up LineSearch
* added Newton-CG to LineSearch
* improved logging for LineSearch
* dummy message for github check
* empty commit
* update BFGS to skip update if negative curvature
* made input more elegant and cleaned up popt structure
* changed @cache to @lru_cache
* branch commit
* branch commit
* re-added stuff that got deleted
* fixed what got lost
* fixed what got lost
* fixed confusing naming of input keys
* more input stuff
* added extract_maxiter to extract_tools
* assertion fix
* fixed input for sim options
* Update extract_tools.py

---------

Co-authored-by: Mathias Methlie Nilsen <[email protected]>
Co-authored-by: Mathias Methlie Nilsen <[email protected]>
Co-authored-by: Rolf J. Lorentzen <[email protected]>
1 parent 994e601 commit 1b0798c

File tree

14 files changed

+554
-309
lines changed

14 files changed

+554
-309
lines changed

ensemble/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Multiple realisations management."""
1+
"""Multiple realisations management."""

ensemble/ensemble.py

Lines changed: 17 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818

1919
# Internal imports
2020
import pipt.misc_tools.analysis_tools as at
21+
import pipt.misc_tools.extract_tools as extract
2122
from geostat.decomp import Cholesky # Making realizations
2223
from pipt.misc_tools import cov_regularization
2324
from pipt.misc_tools import wavelet_tools as wt
2425
from misc import read_input_csv as rcsv
2526
from misc.system_tools.environ_var import OpenBlasSingleThread # Single threaded OpenBLAS runs
2627

2728

29+
2830
class Ensemble:
2931
"""
3032
Class for organizing misc. variables and simulator for an ensemble-based inversion run. Here, the forecast step
@@ -56,11 +58,13 @@ def __init__(self, keys_en, sim, redund_sim=None):
5658
self.aux_input = None
5759

5860
# Setup logger
59-
logging.basicConfig(level=logging.INFO,
60-
filename='pet_logger.log',
61-
filemode='w',
62-
format='%(asctime)s : %(levelname)s : %(name)s : %(message)s',
63-
datefmt='%Y-%m-%d %H:%M:%S')
61+
logging.basicConfig(
62+
level=logging.INFO,
63+
filename='pet_logger.log',
64+
filemode='w',
65+
format='%(asctime)s : %(levelname)s : %(name)s : %(message)s',
66+
datefmt='%Y-%m-%d %H:%M:%S'
67+
)
6468
self.logger = logging.getLogger('PET')
6569

6670
# Check if folder contains any En_ files, and remove them!
@@ -117,7 +121,7 @@ def __init__(self, keys_en, sim, redund_sim=None):
117121
self.disable_tqdm = False
118122

119123
# extract information that is given for the prior model
120-
self.prior_info = self._extract_prior_info()
124+
self.prior_info = extract.extract_prior_info(self.keys_en)
121125

122126
# Calculate initial ensemble if IMPORTSTATICVAR has not been given in init. file.
123127
# Prior info. on state variables must be given by PRIOR_<STATICVAR-name> keyword.
@@ -143,7 +147,11 @@ def __init__(self, keys_en, sim, redund_sim=None):
143147
print('\033[1;33mInput states have different ensemble size\033[1;m')
144148
sys.exit(1)
145149
self.ne = min(tmp_ne)
146-
self._ext_ml_info()
150+
151+
if 'multilevel' in self.keys_en:
152+
ml_info = extract.extract_multilevel_info(self.keys_en)
153+
self.multilevel, self.tot_level, self.ml_ne, self.ML_error_corr, self.error_comp_scheme, self.ML_corr_done = ml_info
154+
#self._ext_ml_info()
147155

148156
def _ext_ml_info(self):
149157
'''
@@ -172,117 +180,7 @@ def _ext_ml_info(self):
172180
self.error_comp_scheme = self.keys_en['multilevel'][i][2]
173181
self.ML_corr_done = False
174182

175-
def _extract_prior_info(self) -> dict:
176-
'''
177-
Extract prior information on STATE from keyword(s) PRIOR_<STATE entries>.
178-
'''
179-
180-
# Get state names as list
181-
state_names = self.keys_en['state']
182-
if not isinstance(state_names, list): state_names = [state_names]
183-
184-
# Check if PRIOR_<state names> exists for each entry in state
185-
for name in state_names:
186-
assert f'prior_{name}' in self.keys_en, \
187-
'PRIOR_{0} is missing! This keyword is needed to make initial ensemble for {0} entered in ' \
188-
'STATE'.format(name.upper())
189-
190-
# define dict to store prior information in
191-
prior_info = {name: None for name in state_names}
192-
193-
# loop over state priors
194-
for name in state_names:
195-
prior = self.keys_en[f'prior_{name}']
196-
197-
# Check if is a list (old way)
198-
if isinstance(prior, list):
199-
# list of lists - old way of inputting prior information
200-
prior_dict = {}
201-
for i, opt in enumerate(list(zip(*prior))[0]):
202-
if opt == 'limits':
203-
prior_dict[opt] = prior[i][1:]
204-
else:
205-
prior_dict[opt] = prior[i][1]
206-
prior = prior_dict
207-
else:
208-
assert isinstance(prior, dict), 'PRIOR_{0} must be a dictionary or list of lists!'.format(name.upper())
209-
210-
211-
# load mean if in file
212-
if isinstance(prior['mean'], str):
213-
assert prior['mean'].endswith('.npz'), 'File name does not end with \'.npz\'!'
214-
load_file = np.load(prior['mean'])
215-
assert len(load_file.files) == 1, \
216-
'More than one variable located in {0}. Only the mean vector can be stored in the .npz file!' \
217-
.format(prior['mean'])
218-
prior['mean'] = load_file[load_file.files[0]]
219-
else: # Single number inputted, make it a list if not already
220-
if not isinstance(prior['mean'], list):
221-
prior['mean'] = [prior['mean']]
222-
223-
# loop over keys in prior
224-
for key in prior.keys():
225-
# ensure that entry is a list
226-
if (not isinstance(prior[key], list)) and (key != 'mean'):
227-
prior[key] = [prior[key]]
228-
229-
# change the name of some keys
230-
prior['variance'] = prior.pop('var', None)
231-
prior['corr_length'] = prior.pop('range', None)
232-
233-
# process grid
234-
if 'grid' in prior:
235-
grid_dim = prior['grid']
236-
237-
# check if 3D-grid
238-
if (len(grid_dim) == 3) and (grid_dim[2] > 1):
239-
nz = int(grid_dim[2])
240-
prior['nz'] = nz
241-
prior['nx'] = int(grid_dim[0])
242-
prior['ny'] = int(grid_dim[1])
243-
244-
245-
# Check mean when values have been inputted directly (not when mean has been loaded)
246-
mean = prior['mean']
247-
if isinstance(mean, list) and len(mean) < nz:
248-
# Check if it is more than one entry and give error
249-
assert len(mean) == 1, \
250-
'Information from MEAN has been given for {0} layers, whereas {1} is needed!' \
251-
.format(len(mean), nz)
252-
253-
# Only 1 entry; copy this to all layers
254-
print(
255-
'\033[1;33mSingle entry for MEAN will be copied to all {0} layers\033[1;m'.format(nz))
256-
prior['mean'] = mean * nz
257-
258-
#check if info. has been given on all layers. In the case it has not been given, we just copy the info. given.
259-
for key in ['vario', 'variance', 'aniso', 'angle', 'corr_length']:
260-
if key in prior.keys():
261-
val = prior[key]
262-
if len(val) < nz:
263-
# Check if it is more than one entry and give error
264-
assert len(val) == 1, \
265-
'Information from {0} has been given for {1} layers, whereas {2} is needed!' \
266-
.format(key.upper(), len(val), nz)
267-
268-
# Only 1 entry; copy this to all layers
269-
print(
270-
'\033[1;33mSingle entry for {0} will be copied to all {1} layers\033[1;m'.format(key.upper(), nz))
271-
prior[key] = val * nz
272-
273-
else:
274-
prior['nx'] = int(grid_dim[0])
275-
prior['ny'] = int(grid_dim[1])
276-
prior['nz'] = 1
277-
278-
prior.pop('grid', None)
279-
280-
# add prior to prior_info
281-
prior_info[name] = prior
282-
283-
return prior_info
284-
285-
183+
286184
def gen_init_ensemble(self):
287185
"""
288186
Generate the initial ensemble of (joint) state vectors using the GeoStat class in the "geostat" package.
@@ -353,6 +251,7 @@ def gen_init_ensemble(self):
353251
# Save the ensemble for later inspection
354252
np.savez('prior.npz', **self.state)
355253

254+
356255
def get_list_assim_steps(self):
357256
"""
358257
Returns list of assimilation steps. Useful in a 'loop'-script.

input_output/read_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ def ndarray_constructor(loader, node):
5555
keys_en = y['ensemble']
5656
check_mand_keywords_en(keys_en)
5757
else:
58-
keys_en = None
58+
keys_en = {}
5959

6060
if 'optim' in y.keys():
6161
keys_pr = y['optim']
6262
check_mand_keywords_opt(keys_pr)
6363
elif 'dataassim' in y.keys():
64-
keys_pr = y['datasssim']
64+
keys_pr = y['dataassim']
6565
check_mand_keywords_da(keys_pr)
6666
else:
6767
raise KeyError
@@ -109,7 +109,7 @@ def read_toml(init_file):
109109
keys_en = t['ensemble']
110110
check_mand_keywords_en(keys_en)
111111
else:
112-
keys_en = None
112+
keys_en = {}
113113
if 'optim' in t.keys():
114114
keys_pr = t['optim']
115115
check_mand_keywords_opt(keys_pr)

pipt/loop/assimilation.py

Lines changed: 17 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
from importlib import import_module
1717

1818
# Internal imports
19-
from pipt.misc_tools import qaqc_tools
19+
from pipt.misc_tools.qaqc_tools import QAQC
2020
from pipt.loop.ensemble import Ensemble
2121
from misc.system_tools.environ_var import OpenBlasSingleThread
2222
from pipt.misc_tools import analysis_tools as at
23+
import pipt.misc_tools.extract_tools as extract
2324

2425

2526
class Assimilate:
@@ -50,7 +51,7 @@ def __init__(self, ensemble: Ensemble):
5051
if hasattr(ensemble, 'max_iter'):
5152
self.max_iter = self.ensemble.max_iter
5253
else:
53-
self.max_iter = self._ext_max_iter()
54+
self.max_iter = extract.extract_maxiter(self.ensemble.keys_da)
5455

5556
# Within variables
5657
self.why_stop = None # Output of why iter. loop stopped
@@ -83,15 +84,20 @@ def run(self):
8384
success_iter = True
8485

8586
# Initiallize progressbar
86-
pbar_out = tqdm(total=self.max_iter,
87-
desc='Iterations (Obj. func. val: )', position=0)
87+
pbar_out = tqdm(total=self.max_iter, desc='Iterations (Obj. func. val: )', position=0)
8888

8989
# Check if we want to perform a Quality Assurance of the forecast
9090
qaqc = None
91-
if 'qa' in self.ensemble.sim.input_dict or 'qc' in self.ensemble.keys_da:
92-
qaqc = qaqc_tools.QAQC({**self.ensemble.keys_da, **self.ensemble.sim.input_dict},
93-
self.ensemble.obs_data, self.ensemble.datavar, self.ensemble.logger,
94-
self.ensemble.prior_info, self.ensemble.sim, self.ensemble.prior_state)
91+
if ('qa' in self.ensemble.sim.input_dict) or ('qc' in self.ensemble.keys_da):
92+
qaqc = QAQC(
93+
self.ensemble.keys_da|self.ensemble.sim.input_dict,
94+
self.ensemble.obs_data,
95+
self.ensemble.datavar,
96+
self.ensemble.logger,
97+
self.ensemble.prior_info,
98+
self.ensemble.sim,
99+
self.ensemble.prior_state
100+
)
95101

96102
# Run a while loop until max. iterations or convergence is reached
97103
while self.ensemble.iteration < self.max_iter and conv is False:
@@ -107,20 +113,17 @@ def run(self):
107113

108114
if 'qa' in self.ensemble.keys_da: # Check if we want to perform a Quality Assurance of the forecast
109115
# set updated prediction, state and lam
110-
qaqc.set(self.ensemble.pred_data,
111-
self.ensemble.state, self.ensemble.lam)
116+
qaqc.set(self.ensemble.pred_data, self.ensemble.state, self.ensemble.lam)
112117
# Level 1,2 all data, and subspace
113118
qaqc.calc_mahalanobis((1, 'time', 2, 'time', 1, None, 2, None))
114119
qaqc.calc_coverage() # Compute data coverage
115-
qaqc.calc_kg({'plot_all_kg': True, 'only_log': False,
116-
'num_store': 5}) # Compute kalman gain
120+
qaqc.calc_kg({'plot_all_kg': True, 'only_log': False, 'num_store': 5}) # Compute kalman gain
117121

118122
success_iter = True
119123

120124
# always store prior forcast, unless specifically told not to
121125
if 'nosave' not in self.ensemble.keys_da:
122-
np.savez('prior_forecast.npz', **
123-
{'pred_data': self.ensemble.pred_data})
126+
np.savez('prior_forecast.npz', pred_data=self.ensemble.pred_data)
124127

125128
# For the remaining iterations we start by applying the analysis and finish by running the forecast
126129
else:
@@ -279,48 +282,6 @@ def remove_outliers(self):
279282
self.ensemble.pred_data[i][el][:, index] = deepcopy(
280283
self.ensemble.pred_data[i][el][:, new_index])
281284

282-
def _ext_max_iter(self):
283-
"""
284-
Extract max iterations from ITERATION keyword in DATAASSIM part (mandatory keyword for iteration loops).
285-
286-
Parameters
287-
----------
288-
keys_da : dict
289-
A dictionary containing all keywords from DATAASSIM part.
290-
291-
- 'iteration' : object
292-
Information for iterative methods.
293-
294-
Returns
295-
-------
296-
max_iter : int
297-
The maximum number of iterations allowed before abort.
298-
299-
Changelog
300-
---------
301-
- ST 7/6-16
302-
"""
303-
if 'iteration' in self.ensemble.keys_da:
304-
iter_opts = dict(self.ensemble.keys_da['iteration'])
305-
# Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
306-
try:
307-
max_iter = iter_opts['max_iter']
308-
except KeyError:
309-
raise AssertionError('MAX_ITER has not been given in ITERATION')
310-
311-
elif 'mda' in self.ensemble.keys_da:
312-
iter_opts = dict(self.ensemble.keys_da['mda'])
313-
# Check if 'tot_assim_steps' has been given; if not, raise error (mandatory in MDA)
314-
try:
315-
max_iter = iter_opts['tot_assim_steps']
316-
except KeyError:
317-
raise AssertionError('TOT_ASSIM_STEPS has not been given in MDA!')
318-
319-
else:
320-
max_iter = 1
321-
# Return max. iter
322-
return max_iter
323-
324285
def _save_iteration_information(self):
325286
"""
326287
More general method for saving all relevant information from a analysis/forecast step. Note that this is

0 commit comments

Comments
 (0)