Skip to content

Commit d4a01a3

Browse files
author
Feras A Saad
committed
Fix #187, eliminate nested defaultdict from vscgpm.
1 parent 71bb54f commit d4a01a3

File tree

2 files changed

+28
-38
lines changed

2 files changed

+28
-38
lines changed

src/venturescript/vscgpm.py

+27-37
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,30 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs):
7676
self.input_mapping = self._get_input_mapping(self.inputs)
7777
# Check overriden observers.
7878
num_observers = self._get_num_observers()
79-
self.obs_override = num_observers is not None
80-
if self.obs_override and len(self.outputs) != num_observers:
79+
self.observe_custom = num_observers is not None
80+
if self.observe_custom and len(self.outputs) != num_observers:
8181
raise ValueError('source.observers list disagrees with outputs.')
82-
# XXX Eliminate this nested defaultdict
83-
# Inputs and labels for incorporate/unincorporate.
84-
self.obs = defaultdict(lambda: defaultdict(dict))
82+
# Entry labels[rowid][query] is label used to observe output cell.
83+
self.labels = dict()
8584

8685
def incorporate(self, rowid, observation, inputs=None):
8786
inputs2 = self._validate_incorporate(rowid, observation, inputs)
87+
if rowid not in self.labels:
88+
self.labels[rowid] = dict()
8889
for i, value in inputs2.iteritems():
8990
self._observe_input_cell(rowid, i, value)
9091
for t, value in observation.iteritems():
9192
self._observe_output_cell(rowid, t, value)
9293

9394
def unincorporate(self, rowid):
94-
if rowid not in self.obs:
95+
if rowid not in self.labels:
9596
raise ValueError('Never incorporated: %d' % rowid)
9697
for q in self.outputs:
9798
self._forget_output_cell(rowid, q)
9899
for i in self.inputs:
99100
self._forget_input_cell(rowid, i)
100-
assert len(self.obs[rowid]['labels']) == 0
101-
del self.obs[rowid]
101+
assert len(self.labels[rowid]) == 0
102+
del self.labels[rowid]
102103

103104
def logpdf(self, rowid, targets, constraints=None, inputs=None):
104105
return 0
@@ -147,7 +148,7 @@ def to_metadata(self):
147148
metadata['mode'] = self.mode
148149
metadata['plugins'] = self.plugins
149150
# Save the observations. We need to convert integer keys to strings.
150-
metadata['obs'] = VsCGpm._obs_to_json(copy.deepcopy(self.obs))
151+
metadata['labels'] = VsCGpm.convert_key_int_to_str(self.labels)
151152
metadata['binary'] = base64.b64encode(self.ripl.saves())
152153
metadata['factory'] = ('cgpm.venturescript.vscgpm', 'VsCGpm')
153154
return metadata
@@ -168,10 +169,9 @@ def from_metadata(cls, metadata, rng=None):
168169
)
169170
# Restore the observations. We need to convert string keys to integers.
170171
# XXX Eliminate this terrible defaultdict hack. See Github #187.
171-
obs_converted = VsCGpm._obs_from_json(metadata['obs'])
172-
cgpm.obs = defaultdict(lambda: defaultdict(dict))
173-
for key, value in obs_converted.iteritems():
174-
cgpm.obs[key] = defaultdict(dict, value)
172+
labels = VsCGpm.convert_key_str_to_int(metadata['labels'])
173+
for rowid, mapping in labels.iteritems():
174+
cgpm.labels[rowid] = mapping
175175
return cgpm
176176

177177
# --------------------------------------------------------------------------
@@ -187,14 +187,14 @@ def _observe_output_cell(self, rowid, query, value):
187187
output_idx = self.outputs.index(query)
188188
label = self._gen_label()
189189
sp_rowid = '(atom %d)' % (rowid,)
190-
if not self.obs_override:
190+
if not self.observe_custom:
191191
self.ripl.observe('((lookup outputs %i) %s)'
192192
% (output_idx, sp_rowid), value, label=label)
193193
else:
194194
obs_args = '%s %s (quote %s)' % (sp_rowid, value, label)
195195
self.ripl.evaluate('((lookup observers %i) %s)'
196196
% (output_idx, obs_args))
197-
self.obs[rowid]['labels'][query] = label
197+
self.labels[rowid][query] = label
198198

199199
def _observe_input_cell(self, rowid, idx, value):
200200
input_name = self.input_mapping[idx]
@@ -206,9 +206,9 @@ def _observe_input_cell(self, rowid, idx, value):
206206

207207
def _forget_output_cell(self, rowid, query):
208208
if self._is_observed_output_cell(rowid, query):
209-
label = self.obs[rowid]['labels'][query]
209+
label = self.labels[rowid][query]
210210
self.ripl.forget(label)
211-
del self.obs[rowid]['labels'][query]
211+
del self.labels[rowid][query]
212212

213213
def _forget_input_cell(self, rowid, idx):
214214
if self._is_observed_input_cell(rowid, idx):
@@ -222,7 +222,7 @@ def _forget_input_cell(self, rowid, idx):
222222
self.ripl.forget(input_cell_name)
223223

224224
def _is_observed_output_cell(self, rowid, query):
225-
return query in self.obs[rowid]['labels']
225+
return rowid in self.labels and query in self.labels[rowid]
226226

227227
def _is_observed_input_cell(self, rowid, idx):
228228
input_name = self.input_mapping[idx]
@@ -257,8 +257,8 @@ def _validate_incorporate(self, rowid, observation, inputs=None):
257257
raise ValueError('Nan inputs: %s' % inputs)
258258
if any(math.isnan(observation[i]) for i in observation):
259259
raise ValueError('Nan observation: %s' % (observation,))
260-
if rowid in self.obs \
261-
and any(q in self.obs[rowid]['labels'] for q in observation):
260+
if rowid in self.labels \
261+
and any(q in self.labels[rowid] for q in observation):
262262
raise ValueError('Observation exists: %d %s' % (rowid, observation))
263263
return self._check_input_args(rowid, inputs)
264264

@@ -307,31 +307,21 @@ def _check_input_args(self, rowid, inputs):
307307
return {i : inputs[i] for i in inputs if i not in inputs_obs}
308308

309309
def _check_constraints_args(self, rowid, constraints):
310-
constraints_obs = [q for q in constraints if rowid in self.obs and
310+
constraints_obs = [q for q in constraints if rowid in self.labels and
311311
self._is_observed_output_cell(rowid, q)]
312312
if constraints_obs:
313313
raise ValueError('Constrained observations exists: %d, %s, %s'
314314
% (rowid, constraints, constraints_obs))
315315

316316
@staticmethod
317-
def _obs_to_json(obs):
318-
def convert_key_int_to_str(d):
319-
assert all(isinstance(c, int) for c in d)
320-
return {str(c): v for c, v in d.iteritems()}
321-
obs2 = convert_key_int_to_str(obs)
322-
for r in obs2:
323-
obs2[r]['labels'] = convert_key_int_to_str(obs2[r]['labels'])
324-
return obs2
317+
def convert_key_int_to_str(d):
318+
assert all(isinstance(c, int) for c in d)
319+
return {str(c): v for c, v in d.iteritems()}
325320

326321
@staticmethod
327-
def _obs_from_json(obs):
328-
def convert_key_str_to_int(d):
329-
assert all(isinstance(c, (str, unicode)) for c in d)
330-
return {int(c): v for c, v in d.iteritems()}
331-
obs2 = convert_key_str_to_int(obs)
332-
for r in obs2:
333-
obs2[r]['labels'] = convert_key_str_to_int(obs2[r]['labels'])
334-
return obs2
322+
def convert_key_str_to_int(d):
323+
assert all(isinstance(c, (str, unicode)) for c in d)
324+
return {int(c): v for c, v in d.iteritems()}
335325

336326
@staticmethod
337327
def _load_helpers(ripl):

tests/test_vscgpm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def test_serialize(case):
249249
assert cgpm.outputs == cgpm_test.outputs
250250
assert cgpm.inputs == cgpm_test.inputs
251251
assert cgpm.source == cgpm_test.source
252-
assert cgpm.obs == cgpm_test.obs
252+
assert cgpm.labels == cgpm_test.labels
253253

254254
sample = cgpm_test.simulate(0, [0,1])
255255
assert sample[0] == 1

0 commit comments

Comments
 (0)