Skip to content

Commit 16abb21

Browse files
author
Feras A Saad
committed
Merge branch '20171121-fsaad-bayeslite-refactoring-fixes'
2 parents dff8a59 + 3275cf8 commit 16abb21

File tree

4 files changed

+6
-121
lines changed

4 files changed

+6
-121
lines changed

src/crosscat/state.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def logpdf(self, rowid, targets, constraints=None, inputs=None,
406406

407407
def simulate(self, rowid, targets, constraints=None, inputs=None,
408408
N=None, accuracy=None):
409-
assert isinstance(targets, list)
409+
assert isinstance(targets, (list, tuple))
410410
assert inputs is None or isinstance(inputs, dict)
411411
self._validate_cgpm_query(rowid, targets, constraints)
412412
if not self._composite:
@@ -448,7 +448,7 @@ def _validate_cgpm_query(self, rowid, targets, constraints):
448448
# Is the rowid fresh?
449449
fresh = self.hypothetical(rowid)
450450
# Is the query simulate or logpdf?
451-
simulate = isinstance(targets, list)
451+
simulate = isinstance(targets, (list, tuple))
452452
# Disallow duplicated target cols.
453453
if simulate and len(set(targets)) != len(targets):
454454
raise ValueError('Columns in targets must be unique.')

src/mixtures/view.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def _bulk_incorporate(self, dim):
486486

487487
def _validate_cgpm_query(self, rowid, targets, constraints):
488488
# Is the query simulate or logpdf?
489-
simulate = isinstance(targets, list)
489+
simulate = isinstance(targets, (list, tuple))
490490
# Disallow duplicated target cols.
491491
if simulate and len(set(targets)) != len(targets):
492492
raise ValueError('Columns in targets must be unique.')

src/utils/validation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def validate_crp_constrained_input(N, Cd, Ci, Rd, Ri):
106106
def partition_query_evidence(Z, query, evidence):
107107
"""Returns queries[k], evidences[k] are queries, evidences for cluster k."""
108108
evidences = partition_dict(Z, evidence) if evidence is not None else dict()
109-
if isinstance(query, list):
110-
queries = partition_list(Z, query)
111-
else:
109+
if isinstance(query, dict):
112110
queries = partition_dict(Z, query)
111+
else:
112+
queries = partition_list(Z, query)
113113
return queries, evidences
114114

115115
def partition_list(Z, L):

tests/test_lovecat.py

-115
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,6 @@
3131

3232
import numpy as np
3333

34-
import bayeslite
35-
36-
from bayeslite.read_csv import bayesdb_read_csv
37-
38-
from crosscat.LocalEngine import LocalEngine
39-
4034
from cgpm.crosscat import lovecat
4135
from cgpm.crosscat.engine import Engine
4236
from cgpm.crosscat.state import State
@@ -45,15 +39,6 @@
4539
from cgpm.utils import test as tu
4640

4741

48-
def nullify(bdb, table, null):
49-
from bayeslite import bql_quote_name
50-
qt = bql_quote_name(table)
51-
for v in (r[1] for r in bdb.sql_execute('PRAGMA table_info(%s)' % (qt,))):
52-
qv = bql_quote_name(v)
53-
bdb.sql_execute(
54-
'UPDATE %s SET %s = NULL WHERE %s = ?' % (qt, qv, qv),
55-
(null,))
56-
5742
# -- Global variables shared by all module functions.
5843
rng = gu.gen_rng(2)
5944

@@ -100,43 +85,6 @@ def generate_dataset_2():
10085
return D
10186

10287

103-
# -------- Create a bdb instance with crosscat -------- #
104-
@contextlib.contextmanager
105-
def generate_bdb(T):
106-
with bayeslite.bayesdb_open(':memory:') as bdb:
107-
# Convert data into csv format and load it.
108-
T_header = str.join(',', ['c%d' % (i,) for i in range(T.shape[1])])
109-
T_data = str.join('\n', [str.join(',', map(str, row)) for row in T])
110-
f = StringIO.StringIO('%s\n%s' % (T_header, T_data))
111-
bayesdb_read_csv(bdb, 'data', f, header=True, create=True)
112-
nullify(bdb, 'data', 'nan')
113-
114-
# Create a population, ignoring column 1.
115-
bdb.execute('''
116-
CREATE POPULATION data_p FOR data WITH SCHEMA(
117-
IGNORE c1;
118-
MODEL c0, c2, c4, c6, c7 AS NUMERICAL;
119-
MODEL c3, c5 AS CATEGORICAL);
120-
''')
121-
122-
# Create a CrossCat metamodel.
123-
bdb.execute('''
124-
CREATE METAMODEL data_m FOR data_p USING crosscat(
125-
c0 NUMERICAL,
126-
c2 NUMERICAL,
127-
c4 NUMERICAL,
128-
c6 NUMERICAL,
129-
c7 NUMERICAL,
130-
131-
c3 CATEGORICAL,
132-
c5 CATEGORICAL);
133-
''')
134-
135-
bdb.execute('INITIALIZE 1 MODEL FOR data_m;')
136-
bdb.execute('ANALYZE data_m FOR 2 ITERATION WAIT;')
137-
yield bdb
138-
139-
14088
# -------- Create a cgpm.state crosscat instance -------- #
14189
def generate_state(T):
14290
# Remember that c1 is ignored.
@@ -156,69 +104,6 @@ def generate_state(T):
156104
return state
157105

158106

159-
def test_cgpm_lovecat_integration():
160-
"""A mix of unit and integration testing for lovecat analysis."""
161-
162-
T = generate_dataset()
163-
164-
with generate_bdb(T) as bdb:
165-
166-
# Retrieve the CrossCat metamodel instance.
167-
metamodel = bdb.metamodels['crosscat']
168-
169-
# Retrieve the cgpm.state
170-
state = generate_state(T)
171-
172-
# Assert that M_c_prime agrees with CrossCat M_c.
173-
M_c_prime = lovecat._crosscat_M_c(state)
174-
M_c = metamodel._crosscat_metadata(bdb, 1)
175-
176-
assert M_c['name_to_idx'] == M_c_prime['name_to_idx']
177-
assert M_c['idx_to_name'] == M_c_prime['idx_to_name']
178-
assert M_c['column_metadata'] == M_c_prime['column_metadata']
179-
180-
# Check that the converted datasets match.
181-
bdb_data = metamodel._crosscat_data(bdb, 1, M_c)
182-
cgpm_data = lovecat._crosscat_T(state, M_c_prime)
183-
assert np.all(np.isclose(bdb_data, cgpm_data, atol=1e-2, equal_nan=True))
184-
185-
# X_L and X_D from the CrossCat state. Not sure what tests to write
186-
# that access theta['X_L'] and theta['X_D'] directly.
187-
theta = metamodel._crosscat_theta(bdb, 1, 0)
188-
189-
# Retrieve X_D and X_L from the cgpm.state, and check they can be used
190-
# as arguments to LocalEngine.analyze.
191-
X_D = lovecat._crosscat_X_D(state, M_c_prime)
192-
X_L = lovecat._crosscat_X_L(state, M_c_prime, X_D)
193-
194-
LE = LocalEngine(seed=4)
195-
start = time.time()
196-
X_L_new, X_D_new = LE.analyze(
197-
M_c_prime, lovecat._crosscat_T(state, M_c_prime),
198-
X_L, X_D, 1, max_time=20, n_steps=100000000,
199-
progress=lovecat._progress)
200-
assert np.allclose(time.time() - start, 20, atol=2)
201-
202-
# This function call updates the cgpm.state internals to
203-
# match X_L_new, X_D_new. Check it does not destory the cgpm.state and
204-
# we can still run transitions.
205-
lovecat._update_state(state, M_c, X_L_new, X_D_new)
206-
state.transition(S=5)
207-
208-
# Invoke a lovecat transition directly through the cgpm.state,
209-
# for 10000 iters with a 5 second timeout.
210-
start = time.time()
211-
state.transition_lovecat(S=7, N=100000)
212-
# Give an extra second for function call overhead.
213-
assert 7. <= time.time() - start <= 8.
214-
215-
# Now invoke by iterations only.
216-
state.transition_lovecat(N=7, progress=False)
217-
218-
# Make sure we can now run regular cgpm.state transitions again.
219-
state.transition(S=5)
220-
221-
222107
def test_lovecat_transition_columns():
223108
"""Test transition_lovecat targeting specific rows and columns."""
224109
D = generate_dataset_2()

0 commit comments

Comments
 (0)