Skip to content

Commit 4ce4d19

Browse files
author
Rolf Johan Lorentzen
committed
Move computation of Am to full_update
1 parent d310b39 commit 4ce4d19

File tree

3 files changed

+32
-20
lines changed

3 files changed

+32
-20
lines changed

pipt/loop/ensemble.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -556,29 +556,12 @@ def _ext_obs(self):
556556
self.real_obs_data, self.scale_data = init_en.gen_real(self.obs_data_vector, self.cov_data, self.ne,
557557
return_chol=True)
558558

559-
def _ext_state(self):
559+
def _ext_scaling(self):
560560
# get vector of scaling
561561
self.state_scaling = at.calc_scaling(
562562
self.prior_state, self.list_states, self.prior_info)
563563

564-
delta_scaled_prior = self.state_scaling[:, None] * \
565-
np.dot(at.aug_state(self.prior_state, self.list_states), self.proj)
566-
567-
u_d, s_d, v_d = np.linalg.svd(delta_scaled_prior, full_matrices=False)
568-
569-
# remove the last singular value/vector. This is because numpy returns all ne values, while the last is actually
570-
# zero. This part is a good place to include eventual additional truncation.
571-
energy = 0
572-
trunc_index = len(s_d) - 1 # inititallize
573-
for c, elem in enumerate(s_d):
574-
energy += elem
575-
if energy / sum(s_d) >= self.trunc_energy:
576-
trunc_index = c # take the index where all energy is preserved
577-
break
578-
u_d, s_d, v_d = u_d[:, :trunc_index +
579-
1], s_d[:trunc_index + 1], v_d[:trunc_index + 1, :]
580-
self.Am = np.dot(u_d, np.eye(trunc_index+1) *
581-
((s_d**(-1))[:, None])) # notation from paper
564+
self.Am = None
582565

583566
def save_temp_state_assim(self, ind_save):
584567
"""

pipt/update_schemes/esmda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self, keys_da, keys_en, sim):
7070
self._ext_obs()
7171
self.real_obs_data_conv = deepcopy(self.real_obs_data)
7272
# Get state scaling and svd of scaled prior
73-
self._ext_state()
73+
self._ext_scaling()
7474
self.current_state = deepcopy(self.state)
7575
# Extract the inflation parameter from MDA keyword
7676
self.alpha = self._ext_inflation_param()

pipt/update_schemes/update_methods_ns/full_update.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,36 @@ class full_update():
1818
no localization is implemented for this method yet.
1919
"""
2020

21+
def ext_Am(self, *args, **kwargs):
22+
"""
23+
The class is initialized by calculating the required Am matrix.
24+
"""
25+
26+
delta_scaled_prior = self.state_scaling[:, None] * \
27+
np.dot(at.aug_state(self.prior_state, self.list_states), self.proj)
28+
29+
u_d, s_d, v_d = np.linalg.svd(delta_scaled_prior, full_matrices=False)
30+
31+
# remove the last singular value/vector. This is because numpy returns all ne values, while the last is actually
32+
# zero. This part is a good place to include eventual additional truncation.
33+
energy = 0
34+
trunc_index = len(s_d) - 1 # inititallize
35+
for c, elem in enumerate(s_d):
36+
energy += elem
37+
if energy / sum(s_d) >= self.trunc_energy:
38+
trunc_index = c # take the index where all energy is preserved
39+
break
40+
u_d, s_d, v_d = u_d[:, :trunc_index +
41+
1], s_d[:trunc_index + 1], v_d[:trunc_index + 1, :]
42+
self.Am = np.dot(u_d, np.eye(trunc_index + 1) *
43+
((s_d ** (-1))[:, None])) # notation from paper
44+
45+
2146
def update(self):
47+
48+
if self.Am is None:
49+
self.ext_Am() # do this only once
50+
2251
aug_state = at.aug_state(self.current_state, self.list_states)
2352
aug_prior_state = at.aug_state(self.prior_state, self.list_states)
2453

0 commit comments

Comments
 (0)