|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +import itertools |
17 | 18 | import base64
|
18 | 19 | import copy
|
19 | 20 | import math
|
|
23 | 24 |
|
24 | 25 | import venture.shortcuts as vs
|
25 | 26 |
|
| 27 | +from venture.exception import VentureException |
| 28 | + |
26 | 29 | from cgpm.cgpm import CGpm
|
27 | 30 | from cgpm.utils import config as cu
|
28 | 31 | from cgpm.utils import general as gu
|
@@ -71,7 +74,9 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs):
|
71 | 74 | raise ValueError('source.inputs list disagrees with inputs.')
|
72 | 75 | self.inputs = inputs
|
73 | 76 | # Check overriden observers.
|
74 |
| - if len(self.outputs) != self.ripl.evaluate('(size observers)'): |
| 77 | + num_observers = self._get_num_observers() |
| 78 | + self.obs_override = num_observers is not None |
| 79 | + if self.obs_override and len(self.outputs) != num_observers: |
75 | 80 | raise ValueError('source.observers list disagrees with outputs.')
|
76 | 81 | # XXX Eliminate this nested defaultdict
|
77 | 82 | # Inputs and labels for incorporate/unincorporate.
|
@@ -174,12 +179,20 @@ def _predict_cell(self, rowid, target, inputs, label):
|
174 | 179 | '((lookup outputs %i) %s)' % (i, sp_args), label=label)
|
175 | 180 |
|
176 | 181 | def _observe_cell(self, rowid, query, value, inputs):
|
| 182 | + output_id = self.outputs.index(query) |
177 | 183 | inputs_list = [inputs[i] for i in self.inputs]
|
178 |
| - label = '\''+self._gen_label() |
179 |
| - sp_args = str.join(' ', map(str, [rowid] + inputs_list + [value, label])) |
180 |
| - i = self.outputs.index(query) |
181 |
| - self.ripl.evaluate('((lookup observers %i) %s)' % (i, sp_args)) |
182 |
| - self.obs[rowid]['labels'][query] = label[1:] |
| 184 | + label = self._gen_label() |
| 185 | + if self.obs_override: |
| 186 | + qlabel = '(quote %s)' % (label,) |
| 187 | + sp_args = ' '.join(map(str, |
| 188 | + itertools.chain([rowid], inputs_list, [value, qlabel]))) |
| 189 | + self.ripl.evaluate('((lookup observers %i) %s)' |
| 190 | + % (output_id, sp_args)) |
| 191 | + else: |
| 192 | + sp_args = ' '.join(map(str, itertools.chain([rowid], inputs_list))) |
| 193 | + self.ripl.observe('((lookup outputs %i) %s)' |
| 194 | + % (output_id, sp_args), value, label=label) |
| 195 | + self.obs[rowid]['labels'][query] = label |
183 | 196 |
|
184 | 197 | def _forget_cell(self, rowid, query):
|
185 | 198 | if query not in self.obs[rowid]['labels']:
|
@@ -255,6 +268,14 @@ def _check_matched_inputs(self, rowid, inputs):
|
255 | 268 | raise ValueError('Given inputs contradicts dataset: %d, %s, %s' %
|
256 | 269 | (rowid, inputs, self.obs[rowid]['inputs']))
|
257 | 270 |
|
| 271 | + def _get_num_observers(self): |
| 272 | + # Return the length of the "observers" list defined by the client, or |
| 273 | + # None if the client did not override the observers. |
| 274 | + try: |
| 275 | + return self.ripl.evaluate('(size observers)') |
| 276 | + except VentureException: |
| 277 | + return None |
| 278 | + |
258 | 279 | @staticmethod
|
259 | 280 | def _obs_to_json(obs):
|
260 | 281 | def convert_key_int_to_str(d):
|
|
0 commit comments