Skip to content

Commit 09ca12d

Browse files
committed
hard copy x before passing it to the solver impl for powell. Use a chainable pointer for the arguments passed to the user defined function (UDF).
1 parent 0b93db8 commit 09ca12d

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

stan/math/prim/functor/algebra_solver_adapter.hpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@ struct algebra_solver_adapter {
1717

1818
explicit algebra_solver_adapter(const F& f) : f_(f) {}
1919

20-
template <typename T1, typename T2, typename T3, typename T4>
21-
auto operator()(const T1& x, std::ostream* msgs, const T2& y, const T3& dat,
22-
const T4& dat_int) const {
23-
return f_(x, y, dat, dat_int, msgs);
20+
template <typename T1, typename... T2>
21+
auto operator()(const T1& x, std::ostream* msgs, T2&&... args) const {
22+
return f_(x, args..., msgs);
2423
}
2524
};
2625

stan/math/rev/functor/algebra_solver_powell.hpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ T& algebra_solver_powell_call_solver(const F& f, T& x, std::ostream* const msgs,
5151
const int64_t max_num_steps,
5252
const Args&... args) {
5353
// Construct the solver
54-
hybrj_functor_solver<decltype(f)> hfs(f);
55-
Eigen::HybridNonLinearSolver<decltype(hfs)> solver(hfs);
54+
hybrj_functor_solver<F> hfs(f);
55+
Eigen::HybridNonLinearSolver<hybrj_functor_solver<F>> solver(hfs);
5656

5757
// Compute theta_dbl
5858
solver.parameters.xtol = relative_tolerance;
@@ -61,7 +61,7 @@ T& algebra_solver_powell_call_solver(const F& f, T& x, std::ostream* const msgs,
6161

6262
// Check if the max number of steps has been exceeded
6363
if (solver.nfev >= max_num_steps) {
64-
[&]() STAN_COLD_PATH {
64+
[max_num_steps]() STAN_COLD_PATH {
6565
throw_domain_error("algebra_solver", "maximum number of iterations",
6666
max_num_steps, "(", ") was exceeded in the solve.");
6767
}();
@@ -70,7 +70,7 @@ T& algebra_solver_powell_call_solver(const F& f, T& x, std::ostream* const msgs,
7070
// Check solution is a root
7171
double system_norm = f(x).stableNorm();
7272
if (system_norm > function_tolerance) {
73-
[&]() STAN_COLD_PATH {
73+
[function_tolerance, system_norm]() STAN_COLD_PATH {
7474
std::ostringstream message;
7575
message << "the norm of the algebraic function is " << system_norm
7676
<< " but should be lower than the function "
@@ -132,7 +132,7 @@ Eigen::VectorXd algebra_solver_powell_impl(const F& f, const T& x,
132132
const double function_tolerance,
133133
const int64_t max_num_steps,
134134
const Args&... args) {
135-
plain_type_t<decltype(to_ref(value_of(x)))> x_ref = to_ref(value_of(x));
135+
auto x_ref = eval(value_of(x));
136136
auto args_vals_tuple = std::make_tuple(to_ref(args)...);
137137

138138
auto f_wrt_x = [&args_vals_tuple, &f, msgs](const auto& x) {
@@ -336,13 +336,13 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_powell_impl(
336336
const F& f, const T& x, std::ostream* const msgs,
337337
const double relative_tolerance, const double function_tolerance,
338338
const int64_t max_num_steps, const T_Args&... args) {
339-
plain_type_t<decltype(to_ref(value_of(x)))> x_ref = to_ref(value_of(x));
340-
auto arena_args_tuple = std::make_tuple(to_arena(args)...);
339+
auto x_ref = eval(value_of(x));
340+
auto arena_args_tuple = make_chainable_ptr(std::make_tuple(eval(args)...));
341341
auto args_vals_tuple = apply(
342342
[&](const auto&... args) {
343343
return std::make_tuple(to_ref(value_of(args))...);
344344
},
345-
arena_args_tuple);
345+
*arena_args_tuple);
346346

347347
auto f_wrt_x = [&args_vals_tuple, &f, msgs](const auto& x) {
348348
return apply(
@@ -386,7 +386,7 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_powell_impl(
386386
Eigen::VectorXd ret_val = ret.val();
387387
auto x_nrad_ = apply(
388388
[&](const auto&... args) { return eval(f(ret_val, msgs, args...)); },
389-
arena_args_tuple);
389+
*arena_args_tuple);
390390
x_nrad_.adj() = eta;
391391
grad();
392392
}

0 commit comments

Comments
 (0)