Skip to content

Commit 0d80bf2

Browse files
author
Feras A Saad
committed
Fix #585 and #586, initialize creates 1 model and remove check_initialized.
1 parent feb1e22 commit 0d80bf2

File tree

2 files changed

+3
-14
lines changed

2 files changed

+3
-14
lines changed

src/metamodels/loom_metamodel.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ def initialize_models(self, bdb, generator_id, modelnos):
322322
(generator_id, num_models)
323323
VALUES (?, ?)
324324
''', (generator_id, len(modelnos) + num_existing))
325+
self.analyze_models(bdb, generator_id, iterations=1)
325326

326327
def _get_num_models(self, bdb, generator_id):
327328
cursor = bdb.sql_execute('''
@@ -493,7 +494,6 @@ def _get_cross_cat(self, bdb, generator_id, modelno):
493494

494495
def column_dependence_probability(self,
495496
bdb, generator_id, modelnos, colno0, colno1):
496-
self._check_loom_initialized(bdb, generator_id)
497497
if modelnos is None:
498498
modelnos = range(self._get_num_models(bdb, generator_id))
499499
if colno0 == colno1:
@@ -534,7 +534,6 @@ def _get_partition_id(self, bdb, generator_id, modelno, kind_id, rowid):
534534

535535
def column_mutual_information(self, bdb, generator_id, modelnos, colnos0,
536536
colnos1, constraints, numsamples):
537-
self._check_loom_initialized(bdb, generator_id)
538537
population_id = core.bayesdb_generator_population(bdb, generator_id)
539538
colnames0 = [str(core.bayesdb_variable_name(bdb, population_id, colno))
540539
for colno in colnos0]
@@ -553,7 +552,6 @@ def column_mutual_information(self, bdb, generator_id, modelnos, colnos0,
553552

554553
def row_similarity(self, bdb, generator_id, modelnos, rowid, target_rowid,
555554
colnos):
556-
self._check_loom_initialized(bdb, generator_id)
557555
if modelnos is None:
558556
modelnos = range(self._get_num_models(bdb, generator_id))
559557
assert len(colnos) == 1
@@ -607,7 +605,6 @@ def _reorder_row(self, bdb, generator_id, row, dense=True):
607605

608606
def predictive_relevance(self, bdb, generator_id, modelnos, rowid_target,
609607
rowid_queries, hypotheticals, colno):
610-
self._check_loom_initialized(bdb, generator_id)
611608
if len(hypotheticals) > 0:
612609
raise BQLError(bdb, 'Loom cannot handle hypothetical rows' \
613610
' because it is unable to insert rows into CrossCat')
@@ -630,7 +627,6 @@ def predictive_relevance(self, bdb, generator_id, modelnos, rowid_target,
630627

631628
def predict_confidence(self, bdb, generator_id, modelnos, rowid, colno,
632629
numsamples=None):
633-
self._check_loom_initialized(bdb, generator_id)
634630
if not numsamples:
635631
numsamples = 2
636632
assert numsamples > 0
@@ -666,7 +662,6 @@ def _is_categorical(stattype):
666662

667663
def simulate_joint(self, bdb, generator_id, modelnos, rowid, targets,
668664
constraints, num_samples=1, accuracy=None):
669-
self._check_loom_initialized(bdb, generator_id)
670665
if rowid != core.bayesdb_generator_fresh_row_id(bdb, generator_id):
671666
row_values_raw = core.bayesdb_generator_row_values(
672667
bdb, generator_id, rowid)
@@ -733,8 +728,6 @@ def simulate_joint(self, bdb, generator_id, modelnos, rowid, targets,
733728

734729
def logpdf_joint(self, bdb, generator_id, modelnos, rowid, targets,
735730
constraints):
736-
self._check_loom_initialized(bdb, generator_id)
737-
738731
population_id = core.bayesdb_generator_population(bdb, generator_id)
739732
ordered_column_labels = self._get_ordered_column_labels(
740733
bdb, generator_id)

tests/test_loom_metamodel.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,8 @@ def test_loom_complex_add_analyze_drop_sequence():
9999
# 3 and not 2 + 3 = 5.
100100
assert num_models == 3
101101

102-
# Check for a bql error if a query is run after initialization
103-
# but before analyze.
104-
with pytest.raises(BQLError):
105-
bdb.execute('estimate probability density of x = 50 from p')
106-
107-
bdb.execute('analyze g for 50 iterations wait')
102+
bdb.execute('analyze g for 10 iterations wait')
103+
bdb.execute('estimate probability density of x = 50 from p')
108104

109105
with pytest.raises(BQLError):
110106
bdb.execute('drop model 1 from g')

0 commit comments

Comments
 (0)