Skip to content

Commit 93331de

Browse files
author
Feras A Saad
committed
Fix #240, default implementation of observers.
Requries user who overrides one observer to override them all. Consider using a dictionary for user to override only some observers.
1 parent 2fe0e9f commit 93331de

File tree

2 files changed

+63
-14
lines changed

2 files changed

+63
-14
lines changed

src/venturescript/vscgpm.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import itertools
1718
import base64
1819
import copy
1920
import math
@@ -23,6 +24,8 @@
2324

2425
import venture.shortcuts as vs
2526

27+
from venture.exception import VentureException
28+
2629
from cgpm.cgpm import CGpm
2730
from cgpm.utils import config as cu
2831
from cgpm.utils import general as gu
@@ -71,7 +74,9 @@ def __init__(self, outputs, inputs, rng=None, sp=None, **kwargs):
7174
raise ValueError('source.inputs list disagrees with inputs.')
7275
self.inputs = inputs
7376
# Check overriden observers.
74-
if len(self.outputs) != self.ripl.evaluate('(size observers)'):
77+
num_observers = self._get_num_observers()
78+
self.obs_override = num_observers is not None
79+
if self.obs_override and len(self.outputs) != num_observers:
7580
raise ValueError('source.observers list disagrees with outputs.')
7681
# XXX Eliminate this nested defaultdict
7782
# Inputs and labels for incorporate/unincorporate.
@@ -174,12 +179,20 @@ def _predict_cell(self, rowid, target, inputs, label):
174179
'((lookup outputs %i) %s)' % (i, sp_args), label=label)
175180

176181
def _observe_cell(self, rowid, query, value, inputs):
182+
output_id = self.outputs.index(query)
177183
inputs_list = [inputs[i] for i in self.inputs]
178-
label = '\''+self._gen_label()
179-
sp_args = str.join(' ', map(str, [rowid] + inputs_list + [value, label]))
180-
i = self.outputs.index(query)
181-
self.ripl.evaluate('((lookup observers %i) %s)' % (i, sp_args))
182-
self.obs[rowid]['labels'][query] = label[1:]
184+
label = self._gen_label()
185+
if self.obs_override:
186+
qlabel = '(quote %s)' % (label,)
187+
sp_args = ' '.join(map(str,
188+
itertools.chain([rowid], inputs_list, [value, qlabel])))
189+
self.ripl.evaluate('((lookup observers %i) %s)'
190+
% (output_id, sp_args))
191+
else:
192+
sp_args = ' '.join(map(str, itertools.chain([rowid], inputs_list)))
193+
self.ripl.observe('((lookup outputs %i) %s)'
194+
% (output_id, sp_args), value, label=label)
195+
self.obs[rowid]['labels'][query] = label
183196

184197
def _forget_cell(self, rowid, query):
185198
if query not in self.obs[rowid]['labels']:
@@ -255,6 +268,14 @@ def _check_matched_inputs(self, rowid, inputs):
255268
raise ValueError('Given inputs contradicts dataset: %d, %s, %s' %
256269
(rowid, inputs, self.obs[rowid]['inputs']))
257270

271+
def _get_num_observers(self):
272+
# Return the length of the "observers" list defined by the client, or
273+
# None if the client did not override the observers.
274+
try:
275+
return self.ripl.evaluate('(size observers)')
276+
except VentureException:
277+
return None
278+
258279
@staticmethod
259280
def _obs_to_json(obs):
260281
def convert_key_int_to_str(d):

tests/test_vscgpm.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,13 @@
6060
(lambda (rowid w value label)
6161
(observe (simulate_y ,rowid ,w) value ,label))]
6262
63-
[define observers (list observe_m
64-
observe_y)]
65-
6663
[define inputs (list 'w)]
6764
6865
[define transition
6966
(lambda (N)
7067
(mh default one N))]
71-
"""
7268
69+
"""
7370

7471
source_concrete = """
7572
define make_cgpm = () -> {
@@ -100,19 +97,42 @@
10097
$label: observe simulate_y($rowid, $w) = value;
10198
};
10299
103-
define observers = [observe_m, observe_y];
104-
105100
define inputs = ["w"];
106101
107102
define transition = (N) -> {
108103
mh(default, one, N)
109104
};
105+
110106
"""
111107

108+
# Define source with client overriding observers.
109+
source_abstract_observers_good = source_abstract + \
110+
'[define observers (list observe_m observe_y)]\n'
111+
source_abstract_observers_bad = source_abstract + \
112+
'[define observers (list observe_m observe_y 2)]\n'
113+
114+
source_concrete_observers_good = source_concrete + \
115+
'define observers = [observe_m, observe_y];\n'
116+
source_concrete_observers_bad = source_concrete + \
117+
'define observers = [observe_m, observe_y, 2];\n'
118+
119+
# Define test cases.
112120
Case = namedtuple('Case', ['source', 'mode'])
113121
cases = [
114-
Case(source_abstract, 'church_prime'),
115-
Case(source_concrete, 'venture_script'),
122+
Case(source_abstract, 'church_prime'),
123+
Case(source_concrete, 'venture_script'),
124+
Case(source_abstract_observers_good, 'church_prime'),
125+
Case(source_concrete_observers_good, 'venture_script'),
126+
]
127+
128+
CaseObs = namedtuple('Case', ['source', 'obsok', 'mode'])
129+
casesObs = [
130+
CaseObs(source_abstract, True, 'church_prime'),
131+
CaseObs(source_concrete, True, 'venture_script'),
132+
CaseObs(source_abstract_observers_good, True, 'church_prime'),
133+
CaseObs(source_concrete_observers_good, True, 'venture_script'),
134+
CaseObs(source_abstract_observers_bad, False, 'church_prime'),
135+
CaseObs(source_concrete_observers_bad, False, 'venture_script'),
116136
]
117137

118138
@pytest.mark.parametrize('case', cases)
@@ -133,6 +153,14 @@ def test_wrong_inputs(case):
133153
with pytest.raises(ValueError):
134154
VsCGpm(outputs=[1,2], inputs=[3,4], source=case.source, mode=case.mode)
135155

156+
@pytest.mark.parametrize('case', casesObs)
157+
def test_wrong_observers(case):
158+
try:
159+
VsCGpm(outputs=[0,1], inputs=[2], source=case.source, mode=case.mode)
160+
assert case.obsok
161+
except ValueError:
162+
assert not case.obsok
163+
136164
@pytest.mark.parametrize('case', cases)
137165
def test_incorporate_unincorporate(case):
138166
cgpm = VsCGpm(outputs=[0,1], inputs=[3], source=case.source, mode=case.mode)

0 commit comments

Comments
 (0)