Skip to content

Commit 48385ba

Browse files
author
Feras A Saad
committed
Fix #187, eliminate nested defaultdict from vscgpm.
1 parent 71bb54f commit 48385ba

File tree

2 files changed: 29 additions (+29) and 40 deletions (-40).

src/venturescript/vscgpm.py

+28 -39
Original file line number | Diff line number | Diff line change
@@ -74,31 +74,32 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs):
7474
raise ValueError('source.inputs list disagrees with inputs.')
7575
self.inputs = inputs
7676
self.input_mapping = self._get_input_mapping(self.inputs)
77-
# Check overriden observers.
77+
# Check custom observers.
7878
num_observers = self._get_num_observers()
79-
self.obs_override = num_observers is not None
80-
if self.obs_override and len(self.outputs) != num_observers:
79+
self.observe_custom = num_observers is not None
80+
if self.observe_custom and len(self.outputs) != num_observers:
8181
raise ValueError('source.observers list disagrees with outputs.')
82-
# XXX Eliminate this nested defaultdict
83-
# Inputs and labels for incorporate/unincorporate.
84-
self.obs = defaultdict(lambda: defaultdict(dict))
82+
# Entry labels[rowid][query] is label used to observe output cell.
83+
self.labels = dict()
8584

8685
def incorporate(self, rowid, observation, inputs=None):
8786
inputs2 = self._validate_incorporate(rowid, observation, inputs)
87+
if rowid not in self.labels:
88+
self.labels[rowid] = dict()
8889
for i, value in inputs2.iteritems():
8990
self._observe_input_cell(rowid, i, value)
9091
for t, value in observation.iteritems():
9192
self._observe_output_cell(rowid, t, value)
9293

9394
def unincorporate(self, rowid):
94-
if rowid not in self.obs:
95+
if rowid not in self.labels:
9596
raise ValueError('Never incorporated: %d' % rowid)
9697
for q in self.outputs:
9798
self._forget_output_cell(rowid, q)
9899
for i in self.inputs:
99100
self._forget_input_cell(rowid, i)
100-
assert len(self.obs[rowid]['labels']) == 0
101-
del self.obs[rowid]
101+
assert len(self.labels[rowid]) == 0
102+
del self.labels[rowid]
102103

103104
def logpdf(self, rowid, targets, constraints=None, inputs=None):
104105
return 0
@@ -147,7 +148,7 @@ def to_metadata(self):
147148
metadata['mode'] = self.mode
148149
metadata['plugins'] = self.plugins
149150
# Save the observations. We need to convert integer keys to strings.
150-
metadata['obs'] = VsCGpm._obs_to_json(copy.deepcopy(self.obs))
151+
metadata['labels'] = VsCGpm.convert_key_int_to_str(self.labels)
151152
metadata['binary'] = base64.b64encode(self.ripl.saves())
152153
metadata['factory'] = ('cgpm.venturescript.vscgpm', 'VsCGpm')
153154
return metadata
@@ -167,11 +168,9 @@ def from_metadata(cls, metadata, rng=None):
167168
rng=rng,
168169
)
169170
# Restore the observations. We need to convert string keys to integers.
170-
# XXX Eliminate this terrible defaultdict hack. See Github #187.
171-
obs_converted = VsCGpm._obs_from_json(metadata['obs'])
172-
cgpm.obs = defaultdict(lambda: defaultdict(dict))
173-
for key, value in obs_converted.iteritems():
174-
cgpm.obs[key] = defaultdict(dict, value)
171+
labels = VsCGpm.convert_key_str_to_int(metadata['labels'])
172+
for rowid, mapping in labels.iteritems():
173+
cgpm.labels[rowid] = mapping
175174
return cgpm
176175

177176
# --------------------------------------------------------------------------
@@ -187,14 +186,14 @@ def _observe_output_cell(self, rowid, query, value):
187186
output_idx = self.outputs.index(query)
188187
label = self._gen_label()
189188
sp_rowid = '(atom %d)' % (rowid,)
190-
if not self.obs_override:
189+
if not self.observe_custom:
191190
self.ripl.observe('((lookup outputs %i) %s)'
192191
% (output_idx, sp_rowid), value, label=label)
193192
else:
194193
obs_args = '%s %s (quote %s)' % (sp_rowid, value, label)
195194
self.ripl.evaluate('((lookup observers %i) %s)'
196195
% (output_idx, obs_args))
197-
self.obs[rowid]['labels'][query] = label
196+
self.labels[rowid][query] = label
198197

199198
def _observe_input_cell(self, rowid, idx, value):
200199
input_name = self.input_mapping[idx]
@@ -206,9 +205,9 @@ def _observe_input_cell(self, rowid, idx, value):
206205

207206
def _forget_output_cell(self, rowid, query):
208207
if self._is_observed_output_cell(rowid, query):
209-
label = self.obs[rowid]['labels'][query]
208+
label = self.labels[rowid][query]
210209
self.ripl.forget(label)
211-
del self.obs[rowid]['labels'][query]
210+
del self.labels[rowid][query]
212211

213212
def _forget_input_cell(self, rowid, idx):
214213
if self._is_observed_input_cell(rowid, idx):
@@ -222,7 +221,7 @@ def _forget_input_cell(self, rowid, idx):
222221
self.ripl.forget(input_cell_name)
223222

224223
def _is_observed_output_cell(self, rowid, query):
225-
return query in self.obs[rowid]['labels']
224+
return rowid in self.labels and query in self.labels[rowid]
226225

227226
def _is_observed_input_cell(self, rowid, idx):
228227
input_name = self.input_mapping[idx]
@@ -257,8 +256,8 @@ def _validate_incorporate(self, rowid, observation, inputs=None):
257256
raise ValueError('Nan inputs: %s' % inputs)
258257
if any(math.isnan(observation[i]) for i in observation):
259258
raise ValueError('Nan observation: %s' % (observation,))
260-
if rowid in self.obs \
261-
and any(q in self.obs[rowid]['labels'] for q in observation):
259+
if rowid in self.labels \
260+
and any(q in self.labels[rowid] for q in observation):
262261
raise ValueError('Observation exists: %d %s' % (rowid, observation))
263262
return self._check_input_args(rowid, inputs)
264263

@@ -307,31 +306,21 @@ def _check_input_args(self, rowid, inputs):
307306
return {i : inputs[i] for i in inputs if i not in inputs_obs}
308307

309308
def _check_constraints_args(self, rowid, constraints):
310-
constraints_obs = [q for q in constraints if rowid in self.obs and
309+
constraints_obs = [q for q in constraints if rowid in self.labels and
311310
self._is_observed_output_cell(rowid, q)]
312311
if constraints_obs:
313312
raise ValueError('Constrained observations exists: %d, %s, %s'
314313
% (rowid, constraints, constraints_obs))
315314

316315
@staticmethod
317-
def _obs_to_json(obs):
318-
def convert_key_int_to_str(d):
319-
assert all(isinstance(c, int) for c in d)
320-
return {str(c): v for c, v in d.iteritems()}
321-
obs2 = convert_key_int_to_str(obs)
322-
for r in obs2:
323-
obs2[r]['labels'] = convert_key_int_to_str(obs2[r]['labels'])
324-
return obs2
316+
def convert_key_int_to_str(d):
317+
assert all(isinstance(c, int) for c in d)
318+
return {str(c): v for c, v in d.iteritems()}
325319

326320
@staticmethod
327-
def _obs_from_json(obs):
328-
def convert_key_str_to_int(d):
329-
assert all(isinstance(c, (str, unicode)) for c in d)
330-
return {int(c): v for c, v in d.iteritems()}
331-
obs2 = convert_key_str_to_int(obs)
332-
for r in obs2:
333-
obs2[r]['labels'] = convert_key_str_to_int(obs2[r]['labels'])
334-
return obs2
321+
def convert_key_str_to_int(d):
322+
assert all(isinstance(c, (str, unicode)) for c in d)
323+
return {int(c): v for c, v in d.iteritems()}
335324

336325
@staticmethod
337326
def _load_helpers(ripl):

tests/test_vscgpm.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,7 @@ def test_serialize(case):
249249
assert cgpm.outputs == cgpm_test.outputs
250250
assert cgpm.inputs == cgpm_test.inputs
251251
assert cgpm.source == cgpm_test.source
252-
assert cgpm.obs == cgpm_test.obs
252+
assert cgpm.labels == cgpm_test.labels
253253

254254
sample = cgpm_test.simulate(0, [0,1])
255255
assert sample[0] == 1

0 commit comments

Comments
 (0)