@@ -74,31 +74,32 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs):
74
74
raise ValueError ('source.inputs list disagrees with inputs.' )
75
75
self .inputs = inputs
76
76
self .input_mapping = self ._get_input_mapping (self .inputs )
77
- # Check overriden observers.
77
+ # Check custom observers.
78
78
num_observers = self ._get_num_observers ()
79
- self .obs_override = num_observers is not None
80
- if self .obs_override and len (self .outputs ) != num_observers :
79
+ self .observe_custom = num_observers is not None
80
+ if self .observe_custom and len (self .outputs ) != num_observers :
81
81
raise ValueError ('source.observers list disagrees with outputs.' )
82
- # XXX Eliminate this nested defaultdict
83
- # Inputs and labels for incorporate/unincorporate.
84
- self .obs = defaultdict (lambda : defaultdict (dict ))
82
+ # Entry labels[rowid][query] is label used to observe output cell.
83
+ self .labels = dict ()
85
84
86
85
def incorporate (self , rowid , observation , inputs = None ):
87
86
inputs2 = self ._validate_incorporate (rowid , observation , inputs )
87
+ if rowid not in self .labels :
88
+ self .labels [rowid ] = dict ()
88
89
for i , value in inputs2 .iteritems ():
89
90
self ._observe_input_cell (rowid , i , value )
90
91
for t , value in observation .iteritems ():
91
92
self ._observe_output_cell (rowid , t , value )
92
93
93
94
def unincorporate (self , rowid ):
94
- if rowid not in self .obs :
95
+ if rowid not in self .labels :
95
96
raise ValueError ('Never incorporated: %d' % rowid )
96
97
for q in self .outputs :
97
98
self ._forget_output_cell (rowid , q )
98
99
for i in self .inputs :
99
100
self ._forget_input_cell (rowid , i )
100
- assert len (self .obs [rowid ][ 'labels' ]) == 0
101
- del self .obs [rowid ]
101
+ assert len (self .labels [rowid ]) == 0
102
+ del self .labels [rowid ]
102
103
103
104
def logpdf (self , rowid , targets , constraints = None , inputs = None ):
104
105
return 0
@@ -147,7 +148,7 @@ def to_metadata(self):
147
148
metadata ['mode' ] = self .mode
148
149
metadata ['plugins' ] = self .plugins
149
150
# Save the observations. We need to convert integer keys to strings.
150
- metadata ['obs ' ] = VsCGpm ._obs_to_json ( copy . deepcopy ( self .obs ) )
151
+ metadata ['labels ' ] = VsCGpm .convert_key_int_to_str ( self .labels )
151
152
metadata ['binary' ] = base64 .b64encode (self .ripl .saves ())
152
153
metadata ['factory' ] = ('cgpm.venturescript.vscgpm' , 'VsCGpm' )
153
154
return metadata
@@ -167,11 +168,9 @@ def from_metadata(cls, metadata, rng=None):
167
168
rng = rng ,
168
169
)
169
170
# Restore the observations. We need to convert string keys to integers.
170
- # XXX Eliminate this terrible defaultdict hack. See Github #187.
171
- obs_converted = VsCGpm ._obs_from_json (metadata ['obs' ])
172
- cgpm .obs = defaultdict (lambda : defaultdict (dict ))
173
- for key , value in obs_converted .iteritems ():
174
- cgpm .obs [key ] = defaultdict (dict , value )
171
+ labels = VsCGpm .convert_key_str_to_int (metadata ['labels' ])
172
+ for rowid , mapping in labels .iteritems ():
173
+ cgpm .labels [rowid ] = mapping
175
174
return cgpm
176
175
177
176
# --------------------------------------------------------------------------
@@ -187,14 +186,14 @@ def _observe_output_cell(self, rowid, query, value):
187
186
output_idx = self .outputs .index (query )
188
187
label = self ._gen_label ()
189
188
sp_rowid = '(atom %d)' % (rowid ,)
190
- if not self .obs_override :
189
+ if not self .observe_custom :
191
190
self .ripl .observe ('((lookup outputs %i) %s)'
192
191
% (output_idx , sp_rowid ), value , label = label )
193
192
else :
194
193
obs_args = '%s %s (quote %s)' % (sp_rowid , value , label )
195
194
self .ripl .evaluate ('((lookup observers %i) %s)'
196
195
% (output_idx , obs_args ))
197
- self .obs [rowid ][ 'labels' ][query ] = label
196
+ self .labels [rowid ][query ] = label
198
197
199
198
def _observe_input_cell (self , rowid , idx , value ):
200
199
input_name = self .input_mapping [idx ]
@@ -206,9 +205,9 @@ def _observe_input_cell(self, rowid, idx, value):
206
205
207
206
def _forget_output_cell (self , rowid , query ):
208
207
if self ._is_observed_output_cell (rowid , query ):
209
- label = self .obs [rowid ][ 'labels' ][query ]
208
+ label = self .labels [rowid ][query ]
210
209
self .ripl .forget (label )
211
- del self .obs [rowid ][ 'labels' ][query ]
210
+ del self .labels [rowid ][query ]
212
211
213
212
def _forget_input_cell (self , rowid , idx ):
214
213
if self ._is_observed_input_cell (rowid , idx ):
@@ -222,7 +221,7 @@ def _forget_input_cell(self, rowid, idx):
222
221
self .ripl .forget (input_cell_name )
223
222
224
223
def _is_observed_output_cell (self , rowid , query ):
225
- return query in self .obs [rowid ][ 'labels' ]
224
+ return rowid in self . labels and query in self .labels [rowid ]
226
225
227
226
def _is_observed_input_cell (self , rowid , idx ):
228
227
input_name = self .input_mapping [idx ]
@@ -257,8 +256,8 @@ def _validate_incorporate(self, rowid, observation, inputs=None):
257
256
raise ValueError ('Nan inputs: %s' % inputs )
258
257
if any (math .isnan (observation [i ]) for i in observation ):
259
258
raise ValueError ('Nan observation: %s' % (observation ,))
260
- if rowid in self .obs \
261
- and any (q in self .obs [rowid ][ 'labels' ] for q in observation ):
259
+ if rowid in self .labels \
260
+ and any (q in self .labels [rowid ] for q in observation ):
262
261
raise ValueError ('Observation exists: %d %s' % (rowid , observation ))
263
262
return self ._check_input_args (rowid , inputs )
264
263
@@ -307,31 +306,21 @@ def _check_input_args(self, rowid, inputs):
307
306
return {i : inputs [i ] for i in inputs if i not in inputs_obs }
308
307
309
308
def _check_constraints_args (self , rowid , constraints ):
310
- constraints_obs = [q for q in constraints if rowid in self .obs and
309
+ constraints_obs = [q for q in constraints if rowid in self .labels and
311
310
self ._is_observed_output_cell (rowid , q )]
312
311
if constraints_obs :
313
312
raise ValueError ('Constrained observations exists: %d, %s, %s'
314
313
% (rowid , constraints , constraints_obs ))
315
314
316
315
@staticmethod
317
- def _obs_to_json (obs ):
318
- def convert_key_int_to_str (d ):
319
- assert all (isinstance (c , int ) for c in d )
320
- return {str (c ): v for c , v in d .iteritems ()}
321
- obs2 = convert_key_int_to_str (obs )
322
- for r in obs2 :
323
- obs2 [r ]['labels' ] = convert_key_int_to_str (obs2 [r ]['labels' ])
324
- return obs2
316
+ def convert_key_int_to_str (d ):
317
+ assert all (isinstance (c , int ) for c in d )
318
+ return {str (c ): v for c , v in d .iteritems ()}
325
319
326
320
@staticmethod
327
- def _obs_from_json (obs ):
328
- def convert_key_str_to_int (d ):
329
- assert all (isinstance (c , (str , unicode )) for c in d )
330
- return {int (c ): v for c , v in d .iteritems ()}
331
- obs2 = convert_key_str_to_int (obs )
332
- for r in obs2 :
333
- obs2 [r ]['labels' ] = convert_key_str_to_int (obs2 [r ]['labels' ])
334
- return obs2
321
+ def convert_key_str_to_int (d ):
322
+ assert all (isinstance (c , (str , unicode )) for c in d )
323
+ return {int (c ): v for c , v in d .iteritems ()}
335
324
336
325
@staticmethod
337
326
def _load_helpers (ripl ):
0 commit comments