Skip to content

Commit 71bb54f

Browse files
author
Feras A Saad
committed
Fixes test_vscgpm.
1 parent 58ba73f commit 71bb54f

File tree

1 file changed

+36
-31
lines changed

1 file changed

+36
-31
lines changed

tests/test_vscgpm.py

+36-31
Original file line numberDiff line numberDiff line change
@@ -169,51 +169,48 @@ def test_wrong_observers(case):
169169
def test_incorporate_unincorporate(case):
170170
cgpm = VsCGpm(outputs=[0,1], inputs=[3], source=case.source, mode=case.mode)
171171

172-
OBS = [[1.2, .2], [1, 4]]
173-
EV = [0, 2]
174-
175-
rowid = 0
172+
observations = [[1.2, .2], [1, 4]]
173+
inputs = [0, 2]
174+
rowid0 = 0
175+
rowid1 = 1
176176

177177
# Missing input will raise a lookup error in Venture.
178178
with pytest.raises(VentureException):
179-
cgpm.incorporate(rowid, {1:OBS[rowid][1]}, {})
179+
cgpm.incorporate(rowid0 , {1: observations[rowid0][1]}, {})
180180
# No query.
181181
with pytest.raises(ValueError):
182-
cgpm.incorporate(rowid, {}, {3:EV[rowid]})
183-
184-
cgpm.incorporate(rowid, {0:OBS[rowid][0]}, {3:EV[rowid]})
185-
182+
cgpm.incorporate(rowid0, {}, {3: inputs[rowid0]})
183+
cgpm.incorporate(rowid0, {0: observations[rowid0][0]}, {3: inputs[rowid0]})
186184
# Duplicate observation.
187185
with pytest.raises(ValueError):
188-
cgpm.incorporate(rowid, {0:OBS[rowid][0]})
186+
cgpm.incorporate(rowid0, {0: observations[rowid0][0]})
189187
# Incompatible evidence.
190188
with pytest.raises(ValueError):
191-
cgpm.incorporate(rowid, {1:OBS[rowid][1]}, {3:EV[rowid]+1})
192-
# Compatible evidence.
193-
cgpm.incorporate(rowid, {1:OBS[rowid][1]}, {3:EV[rowid]})
189+
cgpm.incorporate(rowid0, {1: observations[rowid0][1]},
190+
{3: inputs[rowid0]+1})
191+
192+
cgpm.incorporate(rowid0, {1: observations[rowid0][1]}, {3: inputs[rowid0]})
194193

195-
rowid = 1
196-
cgpm.incorporate(rowid, {1:OBS[rowid][1]}, {3:EV[rowid]})
197-
# Optional evidence.
198-
cgpm.incorporate(rowid, {0:OBS[rowid][0]})
194+
cgpm.incorporate(rowid1, {0: observations[rowid1][0]})
195+
cgpm.incorporate(rowid1, {1: observations[rowid1][1]}, {3: inputs[rowid1]})
199196

200197
# Test observation stable after transition.
201198
def test_samples_match():
202-
# Check all samples match.
203-
sample = cgpm.simulate(0, [0,1])
204-
assert sample[0] == OBS[0][0]
205-
assert sample[1] == OBS[0][1]
206-
207-
sample = cgpm.simulate(1, [1])
208-
assert sample[1] == OBS[rowid][1]
199+
# Test rowid0.
200+
sample = cgpm.simulate(rowid0, [0,1])
201+
assert sample[0] == observations[rowid0][0]
202+
assert sample[1] == observations[rowid0][1]
203+
# Test rowid1.
204+
sample = cgpm.simulate(rowid1, [1])
205+
assert sample[1] == observations[rowid1][1]
209206
sample = cgpm.simulate(1, [0])
210-
assert sample[0] == OBS[rowid][0]
207+
assert sample[0] == observations[rowid1][0]
211208

212209
test_samples_match()
213210
cgpm.transition(N=10)
214211
test_samples_match()
215212

216-
# Test that simulating a hypothetical twice is different.
213+
# Test simulating hypothetical rowid twice gives different results.
217214
first = cgpm.simulate(-100, [0, 1], None, {3:4})
218215
second = cgpm.simulate(-100, [0, 1], None, {3:4})
219216
assert first != second
@@ -222,12 +219,12 @@ def test_samples_match():
222219
cgpm.unincorporate(1)
223220
cgpm.simulate(1, [0])
224221
with pytest.raises(VentureException):
225-
# Missing inputs required for output 1.
222+
# Missing inputs, w is required for output 1.
226223
cgpm.simulate(1, [1])
227224
cgpm.transition(N=10)
228-
sample = cgpm.simulate(1, [0,1], None, {3:EV[rowid]})
229-
assert not np.allclose(sample[0], OBS[rowid][0])
230-
assert not np.allclose(sample[1], OBS[rowid][1])
225+
sample = cgpm.simulate(1, [0,1], None, {3: inputs[rowid1]})
226+
assert not np.allclose(sample[0], observations[rowid1][0])
227+
assert not np.allclose(sample[1], observations[rowid1][1])
231228

232229

233230
@pytest.mark.parametrize('case', cases)
@@ -248,7 +245,6 @@ def test_serialize(case):
248245
# Load binary from JSON.
249246
cgpm3 = builder.from_metadata(json.loads(json.dumps(binary)))
250247

251-
print
252248
for cgpm_test in [cgpm2]:
253249
assert cgpm.outputs == cgpm_test.outputs
254250
assert cgpm.inputs == cgpm_test.inputs
@@ -262,7 +258,16 @@ def test_serialize(case):
262258
sample = cgpm_test.simulate(1, [1])
263259
assert sample[1] == 15
264260

261+
assert cgpm_test._get_input_cell_value(0, 3) == 0
262+
assert cgpm_test._get_input_cell_value(1, 3) == 10
263+
264+
assert cgpm_test._is_observed_output_cell(0, 0)
265+
assert cgpm_test._is_observed_output_cell(0, 1)
266+
assert cgpm_test._is_observed_output_cell(1, 1)
267+
assert not cgpm_test._is_observed_output_cell(1, 0)
268+
265269
cgpm_test.incorporate(1, {0:10})
270+
assert cgpm_test._is_observed_output_cell(1, 0)
266271

267272

268273
@pytest.mark.xfail(strict=True, reason='Github issue #215 (serialization).')

0 commit comments

Comments
 (0)