3
3
4
4
#include < stan/math/rev/core.hpp>
5
5
#include < stan/math/rev/functor/algebra_system.hpp>
6
- #include < stan/math/rev/functor/algebra_solver_powell.hpp>
7
6
#include < stan/math/rev/functor/kinsol_solve.hpp>
8
7
#include < stan/math/prim/err.hpp>
9
- #include < stan/math/prim/fun/mdivide_left.hpp>
10
8
#include < stan/math/prim/fun/value_of.hpp>
9
+ #include < stan/math/prim/functor/algebra_solver_adapter.hpp>
11
10
#include < unsupported/Eigen/NonLinearOptimization>
12
11
#include < iostream>
13
12
#include < string>
@@ -25,53 +24,180 @@ namespace math {
25
24
* The user can also specify the scaled step size, the function
26
25
* tolerance, and the maximum number of steps.
27
26
*
27
+ * This overload handles non-autodiff parameters.
28
+ *
28
29
* @tparam F type of equation system function.
29
30
* @tparam T type of initial guess vector.
31
+ * @tparam Args types of additional parameters to the equation system functor
30
32
*
31
33
* @param[in] f Functor that evaluates the system of equations.
32
34
* @param[in] x Vector of starting values.
33
- * @param[in] y Parameter vector for the equation system. The function
34
- * is overloaded to treat y as a vector of doubles or of a
35
- * a template type T.
36
- * @param[in] dat Continuous data vector for the equation system.
37
- * @param[in] dat_int Integer data vector for the equation system.
38
35
* @param[in, out] msgs The print stream for warning messages.
39
36
* @param[in] scaling_step_size Scaled-step stopping tolerance. If
40
37
* a Newton step is smaller than the scaling step
41
38
* tolerance, the code breaks, assuming the solver is no
42
39
* longer making significant progress (i.e. is stuck)
43
40
* @param[in] function_tolerance determines whether roots are acceptable.
44
41
* @param[in] max_num_steps maximum number of function evaluations.
45
- * * @throw <code>std::invalid_argument</code> if x has size zero.
42
+ * @param[in] args Additional parameters to the equation system functor.
43
+ * @return theta Vector of solutions to the system of equations.
44
+ * @pre f returns finite values when passed any value of x and the given args.
45
+ * @throw <code>std::invalid_argument</code> if x has size zero.
46
+ * @throw <code>std::invalid_argument</code> if x has non-finite elements.
47
+ * @throw <code>std::invalid_argument</code> if scaling_step_size is strictly
48
+ * negative.
49
+ * @throw <code>std::invalid_argument</code> if function_tolerance is strictly
50
+ * negative.
51
+ * @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
52
+ * @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
53
+ */
54
+ template <typename F, typename T, typename ... Args,
55
+ require_eigen_vector_t <T>* = nullptr ,
56
+ require_all_st_arithmetic<Args...>* = nullptr >
57
+ Eigen::VectorXd algebra_solver_newton_impl (const F& f, const T& x,
58
+ std::ostream* const msgs,
59
+ const double scaling_step_size,
60
+ const double function_tolerance,
61
+ const int64_t max_num_steps,
62
+ const Args&... args) {
63
+ const auto & x_ref = to_ref (value_of (x));
64
+
65
+ check_nonzero_size (" algebra_solver_newton" , " initial guess" , x_ref);
66
+ check_finite (" algebra_solver_newton" , " initial guess" , x_ref);
67
+ check_nonnegative (" algebra_solver_newton" , " scaling_step_size" ,
68
+ scaling_step_size);
69
+ check_nonnegative (" algebra_solver_newton" , " function_tolerance" ,
70
+ function_tolerance);
71
+ check_positive (" algebra_solver_newton" , " max_num_steps" , max_num_steps);
72
+
73
+ return kinsol_solve (f, x_ref, scaling_step_size, function_tolerance,
74
+ max_num_steps, 1 , 10 , KIN_LINESEARCH, msgs, args...);
75
+ }
76
+
77
+ /* *
78
+ * Return the solution to the specified system of algebraic
79
+ * equations given an initial guess, and parameters and data,
80
+ * which get passed into the algebraic system. Use the
81
+ * KINSOL solver from the SUNDIALS suite.
82
+ *
83
+ * The user can also specify the scaled step size, the function
84
+ * tolerance, and the maximum number of steps.
85
+ *
86
+ * This overload handles var parameters.
87
+ *
88
+ * The Jacobian \(J_{xy}\) (i.e., Jacobian of unknown \(x\) w.r.t. the parameter
89
+ * \(y\)) is calculated given the solution as follows. Since
90
+ * \[
91
+ * f(x, y) = 0,
92
+ * \]
93
+ * we have (\(J_{pq}\) being the Jacobian matrix \(\tfrac {dp} {dq}\))
94
+ * \[
95
+ * - J_{fx} J_{xy} = J_{fy},
96
+ * \]
97
+ * and therefore \(J_{xy}\) can be solved from system
98
+ * \[
99
+ * - J_{fx} J_{xy} = J_{fy}.
100
+ * \]
101
+ * Let \(\eta\) be the adjoint with respect to \(x\); then to calculate
102
+ * \[
103
+ * \eta J_{xy},
104
+ * \]
105
+ * we solve
106
+ * \[
107
+ * - (\eta J_{fx}^{-1}) J_{fy}.
108
+ * \]
109
+ *
110
+ * @tparam F type of equation system function.
111
+ * @tparam T type of initial guess vector.
112
+ * @tparam Args types of additional parameters to the equation system functor
113
+ *
114
+ * @param[in] f Functor that evaluates the system of equations.
115
+ * @param[in] x Vector of starting values.
116
+ * @param[in, out] msgs The print stream for warning messages.
117
+ * @param[in] scaling_step_size Scaled-step stopping tolerance. If
118
+ * a Newton step is smaller than the scaling step
119
+ * tolerance, the code breaks, assuming the solver is no
120
+ * longer making significant progress (i.e. is stuck)
121
+ * @param[in] function_tolerance determines whether roots are acceptable.
122
+ * @param[in] max_num_steps maximum number of function evaluations.
123
+ * @param[in] args Additional parameters to the equation system functor.
124
+ * @return theta Vector of solutions to the system of equations.
125
+ * @pre f returns finite values when passed any value of x and the given args.
126
+ * @throw <code>std::invalid_argument</code> if x has size zero.
46
127
* @throw <code>std::invalid_argument</code> if x has non-finite elements.
47
- * @throw <code>std::invalid_argument</code> if y has non-finite elements.
48
- * @throw <code>std::invalid_argument</code> if dat has non-finite elements.
49
- * @throw <code>std::invalid_argument</code> if dat_int has non-finite elements.
50
128
* @throw <code>std::invalid_argument</code> if scaling_step_size is strictly
51
129
* negative.
52
130
* @throw <code>std::invalid_argument</code> if function_tolerance is strictly
53
131
* negative.
54
132
* @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
55
- * @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
133
+ * @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
56
134
*/
57
- template <typename F, typename T, require_eigen_vector_t <T>* = nullptr >
58
- Eigen::VectorXd algebra_solver_newton (
59
- const F& f, const T& x, const Eigen::VectorXd& y,
60
- const std::vector<double >& dat, const std::vector<int >& dat_int,
61
- std::ostream* msgs = nullptr , double scaling_step_size = 1e-3 ,
62
- double function_tolerance = 1e-6 ,
63
- long int max_num_steps = 200 ) { // NOLINT(runtime/int)
64
- const auto & x_eval = x.eval ();
65
- algebra_solver_check (x_eval, y, dat, dat_int, function_tolerance,
66
- max_num_steps);
67
- check_nonnegative (" algebra_solver" , " scaling_step_size" , scaling_step_size);
68
-
69
- check_matching_sizes (" algebra_solver" , " the algebraic system's output" ,
70
- value_of (f (x_eval, y, dat, dat_int, msgs)),
71
- " the vector of unknowns, x," , x);
72
-
73
- return kinsol_solve (f, value_of (x_eval), y, dat, dat_int, 0 ,
74
- scaling_step_size, function_tolerance, max_num_steps);
135
+ template <typename F, typename T, typename ... T_Args,
136
+ require_eigen_vector_t <T>* = nullptr ,
137
+ require_any_st_var<T_Args...>* = nullptr >
138
+ Eigen::Matrix<var, Eigen::Dynamic, 1 > algebra_solver_newton_impl (
139
+ const F& f, const T& x, std::ostream* const msgs,
140
+ const double scaling_step_size, const double function_tolerance,
141
+ const int64_t max_num_steps, const T_Args&... args) {
142
+ const auto & x_ref = to_ref (value_of (x));
143
+ auto arena_args_tuple = std::make_tuple (to_arena (args)...);
144
+ auto args_vals_tuple = apply (
145
+ [&](const auto &... args) {
146
+ return std::make_tuple (to_ref (value_of (args))...);
147
+ },
148
+ arena_args_tuple);
149
+
150
+ check_nonzero_size (" algebra_solver_newton" , " initial guess" , x_ref);
151
+ check_finite (" algebra_solver_newton" , " initial guess" , x_ref);
152
+ check_nonnegative (" algebra_solver_newton" , " scaling_step_size" ,
153
+ scaling_step_size);
154
+ check_nonnegative (" algebra_solver_newton" , " function_tolerance" ,
155
+ function_tolerance);
156
+ check_positive (" algebra_solver_newton" , " max_num_steps" , max_num_steps);
157
+
158
+ // Solve the system
159
+ Eigen::VectorXd theta_dbl = apply (
160
+ [&](const auto &... vals) {
161
+ return kinsol_solve (f, x_ref, scaling_step_size, function_tolerance,
162
+ max_num_steps, 1 , 10 , KIN_LINESEARCH, msgs,
163
+ vals...);
164
+ },
165
+ args_vals_tuple);
166
+
167
+ auto f_wrt_x = [&](const auto & x) {
168
+ return apply ([&](const auto &... args) { return f (x, msgs, args...); },
169
+ args_vals_tuple);
170
+ };
171
+
172
+ Eigen::MatrixXd Jf_x;
173
+ Eigen::VectorXd f_x;
174
+
175
+ jacobian (f_wrt_x, theta_dbl, f_x, Jf_x);
176
+
177
+ using ret_type = Eigen::Matrix<var, Eigen::Dynamic, -1 >;
178
+ arena_t <ret_type> ret = theta_dbl;
179
+ auto Jf_x_T_lu_ptr
180
+ = make_unsafe_chainable_ptr (Jf_x.transpose ().partialPivLu ()); // Lu
181
+
182
+ reverse_pass_callback ([f, ret, arena_args_tuple, Jf_x_T_lu_ptr,
183
+ msgs]() mutable {
184
+ Eigen::VectorXd eta = -Jf_x_T_lu_ptr->solve (ret.adj ().eval ());
185
+
186
+ // Contract with Jacobian of f with respect to y using a nested reverse
187
+ // autodiff pass.
188
+ {
189
+ nested_rev_autodiff rev;
190
+
191
+ Eigen::VectorXd ret_val = ret.val ();
192
+ auto x_nrad_ = apply (
193
+ [&](const auto &... args) { return eval (f (ret_val, msgs, args...)); },
194
+ arena_args_tuple);
195
+ x_nrad_.adj () = eta;
196
+ grad ();
197
+ }
198
+ });
199
+
200
+ return ret_type (ret);
75
201
}
76
202
77
203
/* *
@@ -83,19 +209,15 @@ Eigen::VectorXd algebra_solver_newton(
83
209
* The user can also specify the scaled step size, the function
84
210
* tolerance, and the maximum number of steps.
85
211
*
86
- * Overload the previous definition to handle the case where y
87
- * is a vector of parameters (var). The overload calls the
88
- * algebraic solver defined above and builds a vari object on
89
- * top, using the algebra_solver_vari class.
212
+ * Signature to maintain backward compatibility, will be removed
213
+ * in the future.
90
214
*
91
215
* @tparam F type of equation system function.
92
216
* @tparam T type of initial guess vector.
93
217
*
94
218
* @param[in] f Functor that evaluates the system of equations.
95
219
* @param[in] x Vector of starting values.
96
- * @param[in] y Parameter vector for the equation system. The function
97
- * is overloaded to treat y as a vector of doubles or of a
98
- * a template type T.
220
+ * @param[in] y Parameter vector for the equation system.
99
221
* @param[in] dat Continuous data vector for the equation system.
100
222
* @param[in] dat_int Integer data vector for the equation system.
101
223
* @param[in, out] msgs The print stream for warning messages.
@@ -119,34 +241,16 @@ Eigen::VectorXd algebra_solver_newton(
119
241
* @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
120
242
*/
121
243
template <typename F, typename T1, typename T2,
122
- require_all_eigen_vector_t <T1, T2>* = nullptr ,
123
- require_st_var<T2>* = nullptr >
244
+ require_all_eigen_vector_t <T1, T2>* = nullptr >
124
245
Eigen::Matrix<scalar_type_t <T2>, Eigen::Dynamic, 1 > algebra_solver_newton (
125
246
const F& f, const T1& x, const T2& y, const std::vector<double >& dat,
126
- const std::vector<int >& dat_int, std::ostream* msgs = nullptr ,
127
- double scaling_step_size = 1e-3 , double function_tolerance = 1e-6 ,
128
- long int max_num_steps = 200 ) { // NOLINT(runtime/int)
129
-
130
- const auto & x_eval = x.eval ();
131
- const auto & y_eval = y.eval ();
132
- Eigen::VectorXd theta_dbl = algebra_solver_newton (
133
- f, x_eval, value_of (y_eval), dat, dat_int, msgs, scaling_step_size,
134
- function_tolerance, max_num_steps);
135
-
136
- typedef system_functor<F, double , double , false > Fy;
137
- typedef system_functor<F, double , double , true > Fs;
138
- typedef hybrj_functor_solver<Fs, F, double , double > Fx;
139
- Fx fx (Fs (), f, value_of (x_eval), value_of (y_eval), dat, dat_int, msgs);
140
-
141
- // Construct vari
142
- auto * vi0 = new algebra_solver_vari<Fy, F, scalar_type_t <T2>, Fx>(
143
- Fy (), f, value_of (x_eval), y_eval, dat, dat_int, theta_dbl, fx, msgs);
144
- Eigen::Matrix<var, Eigen::Dynamic, 1 > theta (x.size ());
145
- theta (0 ) = var (vi0->theta_ [0 ]);
146
- for (int i = 1 ; i < x.size (); ++i)
147
- theta (i) = var (vi0->theta_ [i]);
148
-
149
- return theta;
247
+ const std::vector<int >& dat_int, std::ostream* const msgs = nullptr ,
248
+ const double scaling_step_size = 1e-3 ,
249
+ const double function_tolerance = 1e-6 ,
250
+ const long int max_num_steps = 200 ) { // NOLINT(runtime/int)
251
+ return algebra_solver_newton_impl (algebra_solver_adapter<F>(f), x, msgs,
252
+ scaling_step_size, function_tolerance,
253
+ max_num_steps, y, dat, dat_int);
150
254
}
151
255
152
256
} // namespace math
0 commit comments