@@ -83,7 +83,7 @@ namespace smt {
8383 return r;
8484 }
8585
86- unsigned parallel::param_generator::replay_proof_prefixes (vector<smt_params> candidate_param_states, unsigned max_conflicts_epsilon=200 ) {
86+ unsigned parallel::param_generator::replay_proof_prefixes (vector<param_values> const & candidate_param_states, unsigned max_conflicts_epsilon=200 ) {
8787 unsigned conflict_budget = m_max_prefix_conflicts + max_conflicts_epsilon;
8888 unsigned best_param_state_idx;
8989 double best_score;
@@ -95,11 +95,11 @@ namespace smt {
9595 double score = 0.0 ;
9696
9797 // apply the ith param state to probe_ctx
98- smt_params params = candidate_param_states[i];
99- params_ref p;
100- params.updt_params (p);
98+ params_ref p = apply_param_values (candidate_param_states[i]);
10199 probe_ctx->updt_params (p);
102100
101+ // todo: m_recorded_cubes as a expr_ref_vector
102+
103103 for (auto const & clause : probe_ctx->m_recorded_clauses ) {
104104 expr_ref_vector negated_lits (probe_ctx->m );
105105 for (literal lit : clause) {
@@ -111,17 +111,13 @@ namespace smt {
111111 }
112112
113113 // Replay the negated clause
114+
114115 lbool r = probe_ctx->check (negated_lits.size (), negated_lits.data ());
115116
116- ::statistics st;
117- probe_ctx->collect_statistics (st);
118- unsigned conflicts = 0 , decisions = 0 , rlimit = 0 ;
117+ unsigned conflicts = probe_ctx->m_stats .m_num_conflicts ;
118+ unsigned decisions = probe_ctx->m_stats .m_num_decisions ;
119119
120- // I can't figure out how to access the statistics fields, I only see an update method
121- // st.get_uint("conflicts", conflicts);
122- // st.get_uint("decisions", decisions);
123- // st.get_uint("rlimit count", rlimit);
124- score += conflicts + decisions + rlimit;
120+ score += conflicts + decisions;
125121 }
126122
127123 if (i == 0 || score < best_score) {
@@ -134,49 +130,40 @@ namespace smt {
134130 }
135131
136132 void parallel::param_generator::init_param_state () {
137- // param_descrs smt_desc;
138- // smt_params_helper::collect_param_descrs(smt_desc);
139133 smt_params_helper smtp (m_p);
140- m_my_param_state.insert (symbol (" smt.arith.nl.branching" ), smtp.arith_nl_branching ());
141- m_my_param_state.insert (symbol (" smt.arith.nl.cross_nested" ), smtp.arith_nl_cross_nested ());
142- m_my_param_state.insert (symbol (" smt.arith.nl.delay" ), smtp.arith_nl_delay ());
143- m_my_param_state.insert (symbol (" smt.arith.nl.expensive_patching" ), smtp.arith_nl_expensive_patching ());
144- m_my_param_state.insert (symbol (" smt.arith.nl.gb" ), smtp.arith_nl_gb ());
145- m_my_param_state.insert (symbol (" smt.arith.nl.horner" ), smtp.arith_nl_horner ());
146- m_my_param_state.insert (symbol (" smt.arith.nl.horner_frequency" ), smtp.arith_nl_horner_frequency ());
147- m_my_param_state.insert (symbol (" smt.arith.nl.optimize_bounds" ), smtp.arith_nl_optimize_bounds ());
148- m_my_param_state.insert (symbol (" smt.arith.nl.propagate_linear_monomials" ), smtp.arith_nl_propagate_linear_monomials ());
149- m_my_param_state.insert (symbol (" smt.arith.nl.tangents" ), smtp.arith_nl_tangents ());
150- };
134+ m_param_state.push_back ({symbol (" smt.arith.nl.branching" ), smtp.arith_nl_branching ()});
135+ m_param_state.push_back ({symbol (" smt.arith.nl.cross_nested" ), smtp.arith_nl_cross_nested ()});
136+ m_param_state.push_back ({symbol (" smt.arith.nl.delay" ), unsigned_value ({smtp.arith_nl_delay (), 5 , 10 })});
137+ m_param_state.push_back ({symbol (" smt.arith.nl.expensive_patching" ), smtp.arith_nl_expensive_patching ()});
138+ m_param_state.push_back ({symbol (" smt.arith.nl.gb" ), smtp.arith_nl_grobner ()});
139+ m_param_state.push_back ({symbol (" smt.arith.nl.horner" ), smtp.arith_nl_horner ()});
140+ m_param_state.push_back ({symbol (" smt.arith.nl.horner_frequency" ), unsigned_value ({smtp.arith_nl_horner_frequency (), 2 , 6 })
141+ });
142+ m_param_state.push_back ({symbol (" smt.arith.nl.optimize_bounds" ), smtp.arith_nl_optimize_bounds ()});
143+ m_param_state.push_back (
144+ {symbol (" smt.arith.nl.propagate_linear_monomials" ), smtp.arith_nl_propagate_linear_monomials ()});
145+ m_param_state.push_back ({symbol (" smt.arith.nl.tangents" ), smtp.arith_nl_tangents ()});
151146
152- // TODO: this should mutate only one field at a time an mutate it based on m_my_param_state to keep it generic.
153-
154- smt_params parallel::param_generator::mutate_param_state () {
155- smt_params p = m_param_state;
156- random_gen m_rand;
157-
158- auto flip_bool = [&](bool &x) {
159- if (m_rand (2 ) == 0 )
160- x = !x;
161- };
147+ };
162148
163- auto mutate_uint = [&](unsigned &x, unsigned lo, unsigned hi) {
164- if ((m_rand () % 2 ) == 0 )
165- x = lo + (m_rand ((hi - lo + 1 )));
166- };
149+ parallel::param_generator::param_values parallel::param_generator::mutate_param_state () {
167150
168- flip_bool (p.m_nl_arith_branching );
169- flip_bool (p.m_nl_arith_cross_nested );
170- mutate_uint (p.m_nl_arith_delay , 5 , 20 );
171- flip_bool (p.m_nl_arith_expensive_patching );
172- flip_bool (p.m_nl_arith_gb );
173- flip_bool (p.m_nl_arith_horner );
174- mutate_uint (p.m_nl_arith_horner_frequency , 2 , 6 );
175- flip_bool (p.m_nl_arith_optimize_bounds );
176- flip_bool (p.m_nl_arith_propagate_linear_monomials );
177- flip_bool (p.m_nl_arith_tangents );
178-
179- return p;
151+ param_values new_param_values (m_param_state);
152+ unsigned index = ctx->get_random_value () % new_param_values.size ();
153+ auto ¶m = new_param_values[index];
154+ if (std::holds_alternative<bool >(param.second )) {
155+ bool value = *std::get_if<bool >(¶m.second );
156+ param.second = !value;
157+ }
158+ else if (std::holds_alternative<unsigned_value>(param.second )) {
159+ auto [value, lo, hi] = *std::get_if<unsigned_value>(¶m.second );
160+ unsigned new_value = value;
161+ while (new_value == value) {
162+ new_value = lo + ctx->get_random_value () % (hi - lo + 1 );
163+ }
164+ std::get<unsigned_value>(param.second ).value = new_value;
165+ }
166+ return new_param_values;
180167 }
181168
182169 void parallel::param_generator::protocol_iteration () {
@@ -185,6 +172,8 @@ namespace smt {
185172
186173 // copy current param state to all param probe contexts, before running the next prefix step
187174 // this ensures that each param probe context replays the prefix from the same configuration
175+
176+ // instead just one one context and reset it each time before copy.
188177 for (unsigned i = 0 ; i < m_param_probe_contexts.size (); ++i) {
189178 context::copy (*ctx, *m_param_probe_contexts[i], true );
190179 }
@@ -195,8 +184,12 @@ namespace smt {
195184 case l_undef: {
196185 // TODO, change from smt_params to a generic param state representation based on params_ref
197186 // only params_ref have effect on updates.
198- smt_params best_param_state = m_param_state;
199- vector<smt_params> candidate_param_states;
187+ param_values best_param_state = m_param_state;
188+ vector<param_values> candidate_param_states;
189+
190+ // you can create the mutations on the fly and get the scores
191+ // you don't have to copy all over each tester.
192+
200193
201194 candidate_param_states.push_back (best_param_state); // first candidate param state is current best
202195 while (candidate_param_states.size () <= N) {
@@ -207,7 +200,8 @@ namespace smt {
207200
208201 if (best_param_state_idx != 0 ) {
209202 m_param_state = candidate_param_states[best_param_state_idx];
210- b.set_param_state (m_param_state);
203+ auto p = apply_param_values (m_param_state);
204+ b.set_param_state (p);
211205 IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER found better param state at index " << best_param_state_idx << " \n " );
212206 } else {
213207 IF_VERBOSE (1 , verbose_stream () << " PARAM TUNER retained current param state\n " );
@@ -318,12 +312,12 @@ namespace smt {
318312 }
319313
320314 parallel::param_generator::param_generator (parallel& p)
321- : p(p), b(p.m_batch_manager), m_param_state(p.ctx.get_fparams()), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
322- ctx = alloc (context, m, m_param_state , m_p);
315+ : p(p), b(p.m_batch_manager), m_p(p.ctx.get_params()), m_l2g(m, p.ctx.m) {
316+ ctx = alloc (context, m, p. ctx . get_fparams () , m_p);
323317 context::copy (p.ctx , *ctx, true );
324318
325319 for (unsigned i = 0 ; i < N; ++i) {
326- m_param_probe_contexts.push_back (alloc (context, m, m_param_state , m_p));
320+ m_param_probe_contexts.push_back (alloc (context, m, ctx-> get_fparams () , m_p));
327321 }
328322
329323 // don't share initial units
@@ -483,7 +477,8 @@ namespace smt {
483477 }
484478 }
485479
486- smt_params parallel::batch_manager::get_best_param_state () {
480+ // todo make this thread safe by not using reference counts implicit in params ref but instead copying the entire structure.
481+ params_ref parallel::batch_manager::get_best_param_state () {
487482 std::scoped_lock lock (mux);
488483 return m_param_state;
489484 }
0 commit comments