Skip to content

Commit 8901f8f

Browse files
some comments and change to how parameter variants are stored
Signed-off-by: Nikolaj Bjorner <[email protected]>
1 parent d3e2527 commit 8901f8f

File tree

2 files changed

+71
-66
lines changed

2 files changed

+71
-66
lines changed

src/smt/smt_parallel.cpp

Lines changed: 53 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ namespace smt {
8383
return r;
8484
}
8585

86-
unsigned parallel::param_generator::replay_proof_prefixes(vector<smt_params> candidate_param_states, unsigned max_conflicts_epsilon=200) {
86+
unsigned parallel::param_generator::replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon=200) {
8787
unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
8888
unsigned best_param_state_idx;
8989
double best_score;
@@ -95,11 +95,11 @@ namespace smt {
9595
double score = 0.0;
9696

9797
// apply the ith param state to probe_ctx
98-
smt_params params = candidate_param_states[i];
99-
params_ref p;
100-
params.updt_params(p);
98+
params_ref p = apply_param_values(candidate_param_states[i]);
10199
probe_ctx->updt_params(p);
102100

101+
// todo: m_recorded_cubes as a expr_ref_vector
102+
103103
for (auto const& clause : probe_ctx->m_recorded_clauses) {
104104
expr_ref_vector negated_lits(probe_ctx->m);
105105
for (literal lit : clause) {
@@ -111,17 +111,13 @@ namespace smt {
111111
}
112112

113113
// Replay the negated clause
114+
114115
lbool r = probe_ctx->check(negated_lits.size(), negated_lits.data());
115116

116-
::statistics st;
117-
probe_ctx->collect_statistics(st);
118-
unsigned conflicts = 0, decisions = 0, rlimit = 0;
117+
unsigned conflicts = probe_ctx->m_stats.m_num_conflicts;
118+
unsigned decisions = probe_ctx->m_stats.m_num_decisions;
119119

120-
// I can't figure out how to access the statistics fields, I only see an update method
121-
// st.get_uint("conflicts", conflicts);
122-
// st.get_uint("decisions", decisions);
123-
// st.get_uint("rlimit count", rlimit);
124-
score += conflicts + decisions + rlimit;
120+
score += conflicts + decisions;
125121
}
126122

127123
if (i == 0 || score < best_score) {
@@ -134,49 +130,40 @@ namespace smt {
134130
}
135131

136132
void parallel::param_generator::init_param_state() {
137-
// param_descrs smt_desc;
138-
// smt_params_helper::collect_param_descrs(smt_desc);
139133
smt_params_helper smtp(m_p);
140-
m_my_param_state.insert(symbol("smt.arith.nl.branching"), smtp.arith_nl_branching());
141-
m_my_param_state.insert(symbol("smt.arith.nl.cross_nested"), smtp.arith_nl_cross_nested());
142-
m_my_param_state.insert(symbol("smt.arith.nl.delay"), smtp.arith_nl_delay());
143-
m_my_param_state.insert(symbol("smt.arith.nl.expensive_patching"), smtp.arith_nl_expensive_patching());
144-
m_my_param_state.insert(symbol("smt.arith.nl.gb"), smtp.arith_nl_gb());
145-
m_my_param_state.insert(symbol("smt.arith.nl.horner"), smtp.arith_nl_horner());
146-
m_my_param_state.insert(symbol("smt.arith.nl.horner_frequency"), smtp.arith_nl_horner_frequency());
147-
m_my_param_state.insert(symbol("smt.arith.nl.optimize_bounds"), smtp.arith_nl_optimize_bounds());
148-
m_my_param_state.insert(symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials());
149-
m_my_param_state.insert(symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents());
150-
};
134+
m_param_state.push_back({symbol("smt.arith.nl.branching"), smtp.arith_nl_branching()});
135+
m_param_state.push_back({symbol("smt.arith.nl.cross_nested"), smtp.arith_nl_cross_nested()});
136+
m_param_state.push_back({symbol("smt.arith.nl.delay"), unsigned_value({smtp.arith_nl_delay(), 5, 10})});
137+
m_param_state.push_back({symbol("smt.arith.nl.expensive_patching"), smtp.arith_nl_expensive_patching()});
138+
m_param_state.push_back({symbol("smt.arith.nl.gb"), smtp.arith_nl_grobner()});
139+
m_param_state.push_back({symbol("smt.arith.nl.horner"), smtp.arith_nl_horner()});
140+
m_param_state.push_back({symbol("smt.arith.nl.horner_frequency"), unsigned_value({smtp.arith_nl_horner_frequency(), 2, 6})
141+
});
142+
m_param_state.push_back({symbol("smt.arith.nl.optimize_bounds"), smtp.arith_nl_optimize_bounds()});
143+
m_param_state.push_back(
144+
{symbol("smt.arith.nl.propagate_linear_monomials"), smtp.arith_nl_propagate_linear_monomials()});
145+
m_param_state.push_back({symbol("smt.arith.nl.tangents"), smtp.arith_nl_tangents()});
151146

152-
// TODO: this should mutate only one field at a time an mutate it based on m_my_param_state to keep it generic.
153-
154-
smt_params parallel::param_generator::mutate_param_state() {
155-
smt_params p = m_param_state;
156-
random_gen m_rand;
157-
158-
auto flip_bool = [&](bool &x) {
159-
if (m_rand(2) == 0)
160-
x = !x;
161-
};
147+
};
162148

163-
auto mutate_uint = [&](unsigned &x, unsigned lo, unsigned hi) {
164-
if ((m_rand() % 2) == 0)
165-
x = lo + (m_rand((hi - lo + 1)));
166-
};
149+
parallel::param_generator::param_values parallel::param_generator::mutate_param_state() {
167150

168-
flip_bool(p.m_nl_arith_branching);
169-
flip_bool(p.m_nl_arith_cross_nested);
170-
mutate_uint(p.m_nl_arith_delay, 5, 20);
171-
flip_bool(p.m_nl_arith_expensive_patching);
172-
flip_bool(p.m_nl_arith_gb);
173-
flip_bool(p.m_nl_arith_horner);
174-
mutate_uint(p.m_nl_arith_horner_frequency, 2, 6);
175-
flip_bool(p.m_nl_arith_optimize_bounds);
176-
flip_bool(p.m_nl_arith_propagate_linear_monomials);
177-
flip_bool(p.m_nl_arith_tangents);
178-
179-
return p;
151+
param_values new_param_values(m_param_state);
152+
unsigned index = ctx->get_random_value() % new_param_values.size();
153+
auto &param = new_param_values[index];
154+
if (std::holds_alternative<bool>(param.second)) {
155+
bool value = *std::get_if<bool>(&param.second);
156+
param.second = !value;
157+
}
158+
else if (std::holds_alternative<unsigned_value>(param.second)) {
159+
auto [value, lo, hi] = *std::get_if<unsigned_value>(&param.second);
160+
unsigned new_value = value;
161+
while (new_value == value) {
162+
new_value = lo + ctx->get_random_value() % (hi - lo + 1);
163+
}
164+
std::get<unsigned_value>(param.second).value = new_value;
165+
}
166+
return new_param_values;
180167
}
181168

182169
void parallel::param_generator::protocol_iteration() {
@@ -185,6 +172,8 @@ namespace smt {
185172

186173
// copy current param state to all param probe contexts, before running the next prefix step
187174
// this ensures that each param probe context replays the prefix from the same configuration
175+
176+
// instead just one one context and reset it each time before copy.
188177
for (unsigned i = 0; i < m_param_probe_contexts.size(); ++i) {
189178
context::copy(*ctx, *m_param_probe_contexts[i], true);
190179
}
@@ -195,8 +184,12 @@ namespace smt {
195184
case l_undef: {
196185
// TODO, change from smt_params to a generic param state representation based on params_ref
197186
// only params_ref have effect on updates.
198-
smt_params best_param_state = m_param_state;
199-
vector<smt_params> candidate_param_states;
187+
param_values best_param_state = m_param_state;
188+
vector<param_values> candidate_param_states;
189+
190+
// you can create the mutations on the fly and get the scores
191+
// you don't have to copy all over each tester.
192+
200193

201194
candidate_param_states.push_back(best_param_state); // first candidate param state is current best
202195
while (candidate_param_states.size() <= N) {
@@ -207,7 +200,8 @@ namespace smt {
207200

208201
if (best_param_state_idx != 0) {
209202
m_param_state = candidate_param_states[best_param_state_idx];
210-
b.set_param_state(m_param_state);
203+
auto p = apply_param_values(m_param_state);
204+
b.set_param_state(p);
211205
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER found better param state at index " << best_param_state_idx << "\n");
212206
} else {
213207
IF_VERBOSE(1, verbose_stream() << " PARAM TUNER retained current param state\n");
@@ -318,12 +312,12 @@ namespace smt {
318312
}
319313

320314
parallel::param_generator::param_generator(parallel& p)
321-
: p(p), b(p.m_batch_manager), m_param_state(p.ctx.get_fparams()), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
322-
ctx = alloc(context, m, m_param_state, m_p);
315+
: p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
316+
ctx = alloc(context, m, p.ctx.get_fparams(), m_p);
323317
context::copy(p.ctx, *ctx, true);
324318

325319
for (unsigned i = 0; i < N; ++i) {
326-
m_param_probe_contexts.push_back(alloc(context, m, m_param_state, m_p));
320+
m_param_probe_contexts.push_back(alloc(context, m, ctx->get_fparams(), m_p));
327321
}
328322

329323
// don't share initial units
@@ -483,7 +477,8 @@ namespace smt {
483477
}
484478
}
485479

486-
smt_params parallel::batch_manager::get_best_param_state() {
480+
// todo make this thread safe by not using reference counts implicit in params ref but instead copying the entire structure.
481+
params_ref parallel::batch_manager::get_best_param_state() {
487482
std::scoped_lock lock(mux);
488483
return m_param_state;
489484
}

src/smt/smt_parallel.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ namespace smt {
8181
std::mutex mux;
8282
state m_state = state::is_running;
8383
stats m_stats;
84-
smt_params m_param_state;
84+
params_ref m_param_state;
8585
using node = search_tree::node<cube_config>;
8686
search_tree::tree<cube_config> m_search_tree;
8787

@@ -106,10 +106,10 @@ namespace smt {
106106
void set_sat(ast_translation& l2g, model& m);
107107
void set_exception(std::string const& msg);
108108
void set_exception(unsigned error_code);
109-
void set_param_state(smt_params const& p) { m_param_state = p; }
109+
void set_param_state(params_ref const& p) { m_param_state.copy(p); }
110110
void collect_statistics(::statistics& st) const;
111111

112-
smt_params get_best_param_state();
112+
params_ref get_best_param_state();
113113
bool get_cube(ast_translation& g2l, unsigned id, expr_ref_vector& cube, node*& n);
114114
void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n);
115115
void split(ast_translation& l2g, unsigned id, node* n, expr* atom);
@@ -139,22 +139,32 @@ namespace smt {
139139

140140
scoped_ptr<context> m_prefix_solver;
141141
scoped_ptr_vector<context> m_param_probe_contexts;
142-
smt_params m_param_state;
143142
params_ref m_p;
144143

145-
using param_value = std::variant<unsigned, bool, double>;
146-
symbol_table<param_value> m_my_param_state;
144+
struct unsigned_value {
145+
unsigned value;
146+
unsigned min_value;
147+
unsigned max_value;
148+
};
149+
using param_value = std::variant<unsigned_value, bool>;
150+
using param_values = vector<std::pair<symbol, param_value>>;
151+
param_values m_param_state;
152+
153+
params_ref apply_param_values(param_values const &pv) {
154+
return m_p;
155+
}
156+
// todo
147157

148158
private:
149159
void init_param_state();
150160

151-
smt_params mutate_param_state();
161+
param_values mutate_param_state();
152162

153163
public:
154164
param_generator(parallel &p);
155165
lbool run_prefix_step();
156166
void protocol_iteration();
157-
unsigned replay_proof_prefixes(vector<smt_params> candidate_param_states, unsigned max_conflicts_epsilon);
167+
unsigned replay_proof_prefixes(vector<param_values> const& candidate_param_states, unsigned max_conflicts_epsilon);
158168

159169
reslimit &limit() {
160170
return m.limit();

0 commit comments

Comments
 (0)