33
44#include < stan/math/rev/core.hpp>
55#include < stan/math/rev/functor/algebra_system.hpp>
6+ #include < stan/math/rev/functor/algebra_solver_powell.hpp>
67#include < stan/math/rev/functor/kinsol_solve.hpp>
78#include < stan/math/prim/err.hpp>
9+ #include < stan/math/prim/fun/mdivide_left.hpp>
810#include < stan/math/prim/fun/value_of.hpp>
9- #include < stan/math/prim/functor/algebra_solver_adapter.hpp>
1011#include < unsupported/Eigen/NonLinearOptimization>
1112#include < iostream>
1213#include < string>
@@ -24,180 +25,53 @@ namespace math {
2425 * The user can also specify the scaled step size, the function
2526 * tolerance, and the maximum number of steps.
2627 *
27- * This overload handles non-autodiff parameters.
28- *
29- * @tparam F type of equation system function.
30- * @tparam T type of initial guess vector.
31- * @tparam Args types of additional parameters to the equation system functor
32- *
33- * @param[in] f Functor that evaluated the system of equations.
34- * @param[in] x Vector of starting values.
35- * @param[in, out] msgs The print stream for warning messages.
36- * @param[in] scaling_step_size Scaled-step stopping tolerance. If
37- * a Newton step is smaller than the scaling step
38- * tolerance, the code breaks, assuming the solver is no
39- * longer making significant progress (i.e. is stuck)
40- * @param[in] function_tolerance determines whether roots are acceptable.
41- * @param[in] max_num_steps maximum number of function evaluations.
42- * @param[in] args Additional parameters to the equation system functor.
43- * @return theta Vector of solutions to the system of equations.
44- * @pre f returns finite values when passed any value of x and the given args.
45- * @throw <code>std::invalid_argument</code> if x has size zero.
46- * @throw <code>std::invalid_argument</code> if x has non-finite elements.
47- * @throw <code>std::invalid_argument</code> if scaled_step_size is strictly
48- * negative.
49- * @throw <code>std::invalid_argument</code> if function_tolerance is strictly
50- * negative.
51- * @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
52- * @throw <code>std::domain_error if solver exceeds max_num_steps.
53- */
// Pre-change ("-" side of this diff) Newton solver entry point for the case
// where none of the extra arguments carry autodiff variables: it validates
// the initial guess and the solver controls, then hands the plain values to
// kinsol_solve and returns its result as an Eigen::VectorXd.
54- template <typename F, typename T, typename ... Args,
55- require_eigen_vector_t <T>* = nullptr ,
56- require_all_st_arithmetic<Args...>* = nullptr >
57- Eigen::VectorXd algebra_solver_newton_impl (const F& f, const T& x,
58- std::ostream* const msgs,
59- const double scaling_step_size,
60- const double function_tolerance,
61- const int64_t max_num_steps,
62- const Args&... args) {
// Only the value of the initial guess is used from here on.
63- const auto & x_ref = to_ref (value_of (x));
64-
// Argument validation: the guess must be non-empty and finite, both
// tolerances non-negative, and the step budget strictly positive.
65- check_nonzero_size (" algebra_solver_newton" , " initial guess" , x_ref);
66- check_finite (" algebra_solver_newton" , " initial guess" , x_ref);
67- check_nonnegative (" algebra_solver_newton" , " scaling_step_size" ,
68- scaling_step_size);
69- check_nonnegative (" algebra_solver_newton" , " function_tolerance" ,
70- function_tolerance);
71- check_positive (" algebra_solver_newton" , " max_num_steps" , max_num_steps);
72-
// NOTE(review): the literals 1 and 10 and KIN_LINESEARCH are forwarded as-is;
// their meaning is defined by kinsol_solve, which is not visible in this
// diff -- confirm against its signature.
73- return kinsol_solve (f, x_ref, scaling_step_size, function_tolerance,
74- max_num_steps, 1 , 10 , KIN_LINESEARCH, msgs, args...);
75- }
76-
77- /* *
78- * Return the solution to the specified system of algebraic
79- * equations given an initial guess, and parameters and data,
80- * which get passed into the algebraic system. Use the
81- * KINSOL solver from the SUNDIALS suite.
82- *
83- * The user can also specify the scaled step size, the function
84- * tolerance, and the maximum number of steps.
85- *
86- * This overload handles var parameters.
87- *
88- * The Jacobian \(J_{xy}\) (i.e., Jacobian of unknown \(x\) w.r.t. the parameter
89- * \(y\)) is calculated given the solution as follows. Since
90- * \[
91- * f(x, y) = 0,
92- * \]
93- * we have (\(J_{pq}\) being the Jacobian matrix \(\tfrac {dp} {dq}\))
94- * \[
95- * - J_{fx} J_{xy} = J_{fy},
96- * \]
97- * and therefore \(J_{xy}\) can be solved from system
98- * \[
99- * - J_{fx} J_{xy} = J_{fy}.
100- * \]
101- * Let \(\eta\) be the adjoint with respect to \(x\); then to calculate
102- * \[
103- * \eta J_{xy},
104- * \]
105- * we solve
106- * \[
107- * - (\eta J_{fx}^{-1}) J_{fy}.
108- * \]
109- *
11028 * @tparam F type of equation system function.
11129 * @tparam T type of initial guess vector.
112- * @tparam Args types of additional parameters to the equation system functor
11330 *
11431 * @param[in] f Functor that evaluates the system of equations.
11532 * @param[in] x Vector of starting values.
33+ * @param[in] y Parameter vector for the equation system. The function
34+ * is overloaded to treat y as a vector of doubles or of
35+ * a template type T.
36+ * @param[in] dat Continuous data vector for the equation system.
37+ * @param[in] dat_int Integer data vector for the equation system.
11638 * @param[in, out] msgs The print stream for warning messages.
11739 * @param[in] scaling_step_size Scaled-step stopping tolerance. If
11840 * a Newton step is smaller than the scaling step
11941 * tolerance, the code breaks, assuming the solver is no
12042 * longer making significant progress (i.e. is stuck)
12143 * @param[in] function_tolerance determines whether roots are acceptable.
12244 * @param[in] max_num_steps maximum number of function evaluations.
123- * @param[in] args Additional parameters to the equation system functor.
124- * @return theta Vector of solutions to the system of equations.
125- * @pre f returns finite values when passed any value of x and the given args.
126- * @throw <code>std::invalid_argument</code> if x has size zero.
45+ * @throw <code>std::invalid_argument</code> if x has size zero.
12746 * @throw <code>std::invalid_argument</code> if x has non-finite elements.
47+ * @throw <code>std::invalid_argument</code> if y has non-finite elements.
48+ * @throw <code>std::invalid_argument</code> if dat has non-finite elements.
49+ * @throw <code>std::invalid_argument</code> if dat_int has non-finite elements.
12850 * @throw <code>std::invalid_argument</code> if scaled_step_size is strictly
12951 * negative.
13052 * @throw <code>std::invalid_argument</code> if function_tolerance is strictly
13153 * negative.
13254 * @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
133- * @throw <code>std::domain_error if solver exceeds max_num_steps.
55+ * @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
13456 */
// This span interleaves two definitions that share the final closing brace:
//  * "-" lines: the pre-change algebra_solver_newton_impl overload used when
//    at least one extra argument carries autodiff variables. It solves the
//    system on plain values and registers a reverse-pass callback that
//    propagates adjoints back to the arguments.
//  * "+" lines: the post-change algebra_solver_newton overload taking y as a
//    vector of doubles, which validates its inputs and calls kinsol_solve.
135- template <typename F, typename T, typename ... T_Args,
136- require_eigen_vector_t <T>* = nullptr ,
137- require_any_st_var<T_Args...>* = nullptr >
138- Eigen::Matrix<var, Eigen::Dynamic, 1 > algebra_solver_newton_impl (
139- const F& f, const T& x, std::ostream* const msgs,
140- const double scaling_step_size, const double function_tolerance,
141- const int64_t max_num_steps, const T_Args&... args) {
// Value of the initial guess, arena copies of the arguments (these are what
// the reverse-pass callback captures), and a second tuple holding only the
// arguments' values for the forward solve.
142- const auto & x_ref = to_ref (value_of (x));
143- auto arena_args_tuple = std::make_tuple (to_arena (args)...);
144- auto args_vals_tuple = apply (
145- [&](const auto &... args) {
146- return std::make_tuple (to_ref (value_of (args))...);
147- },
148- arena_args_tuple);
149-
// Same argument validation as the arithmetic overload.
150- check_nonzero_size (" algebra_solver_newton" , " initial guess" , x_ref);
151- check_finite (" algebra_solver_newton" , " initial guess" , x_ref);
152- check_nonnegative (" algebra_solver_newton" , " scaling_step_size" ,
153- scaling_step_size);
154- check_nonnegative (" algebra_solver_newton" , " function_tolerance" ,
155- function_tolerance);
156- check_positive (" algebra_solver_newton" , " max_num_steps" , max_num_steps);
157-
158- // Solve the system
159- Eigen::VectorXd theta_dbl = apply (
160- [&](const auto &... vals) {
161- return kinsol_solve (f, x_ref, scaling_step_size, function_tolerance,
162- max_num_steps, 1 , 10 , KIN_LINESEARCH, msgs,
163- vals...);
164- },
165- args_vals_tuple);
166-
// f viewed as a function of x only, with the argument values held fixed;
// used below to obtain the Jacobian of f with respect to x at the solution.
167- auto f_wrt_x = [&](const auto & x) {
168- return apply ([&](const auto &... args) { return f (x, msgs, args...); },
169- args_vals_tuple);
170- };
171-
172- Eigen::MatrixXd Jf_x;
173- Eigen::VectorXd f_x;
174-
175- jacobian (f_wrt_x, theta_dbl, f_x, Jf_x);
176-
// NOTE(review): the column dimension here is -1, whereas the declared return
// type is Eigen::Matrix<var, Eigen::Dynamic, 1>. Presumably -1 equals
// Eigen::Dynamic, making ret_type a dynamically sized matrix rather than a
// column vector -- confirm this was intended.
177- using ret_type = Eigen::Matrix<var, Eigen::Dynamic, -1 >;
178- arena_t <ret_type> ret = theta_dbl;
// The LU factorization of the transposed Jacobian is computed once here and
// reused inside the reverse-pass callback.
179- auto Jf_x_T_lu_ptr
180- = make_unsafe_chainable_ptr (Jf_x.transpose ().partialPivLu ()); // Lu
181-
182- reverse_pass_callback ([f, ret, arena_args_tuple, Jf_x_T_lu_ptr,
183- msgs]() mutable {
// eta is the negated solution of the transposed-Jacobian system whose
// right-hand side is the adjoint of the returned solution.
184- Eigen::VectorXd eta = -Jf_x_T_lu_ptr->solve (ret.adj ().eval ());
185-
186- // Contract with Jacobian of f with respect to y using a nested reverse
187- // autodiff pass.
188- {
189- nested_rev_autodiff rev;
190-
191- Eigen::VectorXd ret_val = ret.val ();
192- auto x_nrad_ = apply (
193- [&](const auto &... args) { return eval (f (ret_val, msgs, args...)); },
194- arena_args_tuple);
// Seed the nested evaluation of f with eta, then run the nested gradient pass.
195- x_nrad_.adj () = eta;
196- grad ();
197- }
198- });
199-
200- return ret_type (ret);
57+ template <typename F, typename T, require_eigen_vector_t <T>* = nullptr >
58+ Eigen::VectorXd algebra_solver_newton (
59+ const F& f, const T& x, const Eigen::VectorXd& y,
60+ const std::vector<double >& dat, const std::vector<int >& dat_int,
61+ std::ostream* msgs = nullptr , double scaling_step_size = 1e-3 ,
62+ double function_tolerance = 1e-6 ,
63+ long int max_num_steps = 200 ) { // NOLINT(runtime/int)
// Evaluate x once so the checks and the solver see a concrete vector.
64+ const auto & x_eval = x.eval ();
// NOTE(review): these checks report under the name "algebra_solver", while
// the removed code above reported under "algebra_solver_newton" -- confirm
// the name change is intended.
65+ algebra_solver_check (x_eval, y, dat, dat_int, function_tolerance,
66+ max_num_steps);
67+ check_nonnegative (" algebra_solver" , " scaling_step_size" , scaling_step_size);
68+
// Evaluates f once at the initial guess to verify that the system returns as
// many outputs as there are unknowns.
69+ check_matching_sizes (" algebra_solver" , " the algebraic system's output" ,
70+ value_of (f (x_eval, y, dat, dat_int, msgs)),
71+ " the vector of unknowns, x," , x);
72+
// NOTE(review): a literal 0 is passed in the slot right after dat_int and
// msgs is not forwarded to kinsol_solve. Presumably that slot is the message
// stream -- confirm against kinsol_solve's signature.
73+ return kinsol_solve (f, value_of (x_eval), y, dat, dat_int, 0 ,
74+ scaling_step_size, function_tolerance, max_num_steps);
20175}
20276
20377/* *
@@ -209,15 +83,19 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
20983 * The user can also specify the scaled step size, the function
21084 * tolerance, and the maximum number of steps.
21185 *
212- * Signature to maintain backward compatibility, will be removed
213- * in the future.
86+ * Overload the previous definition to handle the case where y
87+ * is a vector of parameters (var). The overload calls the
88+ * algebraic solver defined above and builds a vari object on
89+ * top, using the algebra_solver_vari class.
21490 *
21591 * @tparam F type of equation system function.
21692 * @tparam T type of initial guess vector.
21793 *
21894 * @param[in] f Functor that evaluates the system of equations.
21995 * @param[in] x Vector of starting values.
220- * @param[in] y Parameter vector for the equation system.
96+ * @param[in] y Parameter vector for the equation system. The function
97+ * is overloaded to treat y as a vector of doubles or of
98+ * a template type T.
22199 * @param[in] dat Continuous data vector for the equation system.
222100 * @param[in] dat_int Integer data vector for the equation system.
223101 * @param[in, out] msgs The print stream for warning messages.
@@ -241,16 +119,34 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
241119 * @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
242120 */
// Overload of algebra_solver_newton with a templated parameter vector y.
//  * "-" lines: the pre-change body, which forwarded everything to
//    algebra_solver_newton_impl through algebra_solver_adapter.
//  * "+" lines: the post-change body, which is additionally constrained by
//    require_st_var<T2>, solves on the value of y via the double overload,
//    and wraps the result in an algebra_solver_vari.
243121template <typename F, typename T1, typename T2,
244- require_all_eigen_vector_t <T1, T2>* = nullptr >
122+ require_all_eigen_vector_t <T1, T2>* = nullptr ,
123+ require_st_var<T2>* = nullptr >
245124Eigen::Matrix<scalar_type_t <T2>, Eigen::Dynamic, 1 > algebra_solver_newton (
246125 const F& f, const T1& x, const T2& y, const std::vector<double >& dat,
247- const std::vector<int >& dat_int, std::ostream* const msgs = nullptr ,
248- const double scaling_step_size = 1e-3 ,
249- const double function_tolerance = 1e-6 ,
250- const long int max_num_steps = 200 ) { // NOLINT(runtime/int)
251- return algebra_solver_newton_impl (algebra_solver_adapter<F>(f), x, msgs,
252- scaling_step_size, function_tolerance,
253- max_num_steps, y, dat, dat_int);
126+ const std::vector<int >& dat_int, std::ostream* msgs = nullptr ,
127+ double scaling_step_size = 1e-3 , double function_tolerance = 1e-6 ,
128+ long int max_num_steps = 200 ) { // NOLINT(runtime/int)
129+
// Evaluate both inputs once, then solve using only the value of y.
130+ const auto & x_eval = x.eval ();
131+ const auto & y_eval = y.eval ();
132+ Eigen::VectorXd theta_dbl = algebra_solver_newton (
133+ f, x_eval, value_of (y_eval), dat, dat_int, msgs, scaling_step_size,
134+ function_tolerance, max_num_steps);
135+
// Fy and Fs differ only in the final boolean template argument; Fs is the
// one handed to the hybrj_functor_solver.
// NOTE(review): the meaning of that flag is defined by system_functor, which
// is not visible in this diff.
136+ typedef system_functor<F, double , double , false > Fy;
137+ typedef system_functor<F, double , double , true > Fs;
138+ typedef hybrj_functor_solver<Fs, F, double , double > Fx;
139+ Fx fx (Fs (), f, value_of (x_eval), value_of (y_eval), dat, dat_int, msgs);
140+
141+ // Construct vari
// NOTE(review): vi0 comes from a raw `new` and is never deleted in this
// function. Presumably the vari is owned by the autodiff memory arena --
// confirm against algebra_solver_vari.
142+ auto * vi0 = new algebra_solver_vari<Fy, F, scalar_type_t <T2>, Fx>(
143+ Fy (), f, value_of (x_eval), y_eval, dat, dat_int, theta_dbl, fx, msgs);
// Wrap each entry of vi0->theta_ in a var. Element 0 is assigned outside the
// loop, which therefore starts at i = 1.
144+ Eigen::Matrix<var, Eigen::Dynamic, 1 > theta (x.size ());
145+ theta (0 ) = var (vi0->theta_ [0 ]);
146+ for (int i = 1 ; i < x.size (); ++i)
147+ theta (i) = var (vi0->theta_ [i]);
148+
149+ return theta;
254150}
255151
256152} // namespace math
0 commit comments