Skip to content

Commit 74c5354

Browse files
authored
Merge pull request #2421 from jgaeb/feature/issue-2401-alg-solver-adjoint
Feature/issue 2401 alg solver adjoint
2 parents 02dc560 + 09ca12d commit 74c5354

13 files changed

+1191
-808
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef STAN_MATH_PRIM_FUNCTOR_ALGEBRA_SOLVER_ADAPTER_HPP
2+
#define STAN_MATH_PRIM_FUNCTOR_ALGEBRA_SOLVER_ADAPTER_HPP
3+
4+
#include <ostream>
5+
#include <vector>
6+
7+
/**
8+
* Adapt the non-variadic algebra_solver_newton and algebra_solver_powell
9+
* arguemts to the variadic algebra_solver_newton_impl and
10+
* algebra_solver_powell_impl interfaces.
11+
*
12+
* @tparam F type of function to adapt.
13+
*/
14+
template <typename F>
15+
struct algebra_solver_adapter {
16+
const F& f_;
17+
18+
explicit algebra_solver_adapter(const F& f) : f_(f) {}
19+
20+
template <typename T1, typename... T2>
21+
auto operator()(const T1& x, std::ostream* msgs, T2&&... args) const {
22+
return f_(x, args..., msgs);
23+
}
24+
};
25+
26+
#endif

stan/math/rev/core/chainable_object.hpp

+55
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,61 @@ auto make_chainable_ptr(T&& obj) {
6262
return &ptr->get();
6363
}
6464

65+
/**
66+
* `unsafe_chainable_object` hold another object and is useful for connecting
67+
* the lifetime of a specific object to the chainable stack. This class
68+
* differs from `chainable_object` in that this class does not evaluate
69+
* expressions.
70+
*
71+
* `unsafe_chainable_object` objects should only be allocated with `new`.
72+
* `unsafe_chainable_object` objects allocated on the stack will result
73+
* in a double free (`obj_` will get destructed once when the
74+
* `unsafe_chainable_object` leaves scope and once when the chainable
75+
* stack memory is recovered).
76+
*
77+
* @tparam T type of object to hold
78+
*/
79+
template <typename T>
80+
class unsafe_chainable_object : public chainable_alloc {
81+
private:
82+
std::decay_t<T> obj_;
83+
84+
public:
85+
/**
86+
* Construct chainable object from another object
87+
*
88+
* @tparam S type of object to hold (must have the same plain type as `T`)
89+
*/
90+
template <typename S,
91+
require_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
92+
explicit unsafe_chainable_object(S&& obj) : obj_(std::forward<S>(obj)) {}
93+
94+
/**
95+
* Return a reference to the underlying object
96+
*
97+
* @return reference to underlying object
98+
*/
99+
inline auto& get() noexcept { return obj_; }
100+
inline const auto& get() const noexcept { return obj_; }
101+
};
102+
103+
/**
104+
* Store the given object in a `chainable_object` so it is destructed
105+
* only when the chainable stack memory is recovered and return
106+
* a pointer to the underlying object This function
107+
* differs from `make_chainable_object` in that this class does not evaluate
108+
* expressions.
109+
*
110+
* @tparam T type of object to hold
111+
* @param obj object to hold
112+
* @return pointer to object held in `chainable_object`
113+
*/
114+
template <typename T>
115+
auto make_unsafe_chainable_ptr(T&& obj) {
116+
auto ptr = new unsafe_chainable_object<T>(std::forward<T>(obj));
117+
return &ptr->get();
118+
}
119+
65120
} // namespace math
66121
} // namespace stan
67122
#endif

stan/math/rev/functor/algebra_solver_newton.hpp

+167-63
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33

44
#include <stan/math/rev/core.hpp>
55
#include <stan/math/rev/functor/algebra_system.hpp>
6-
#include <stan/math/rev/functor/algebra_solver_powell.hpp>
76
#include <stan/math/rev/functor/kinsol_solve.hpp>
87
#include <stan/math/prim/err.hpp>
9-
#include <stan/math/prim/fun/mdivide_left.hpp>
108
#include <stan/math/prim/fun/value_of.hpp>
9+
#include <stan/math/prim/functor/algebra_solver_adapter.hpp>
1110
#include <unsupported/Eigen/NonLinearOptimization>
1211
#include <iostream>
1312
#include <string>
@@ -25,53 +24,180 @@ namespace math {
2524
* The user can also specify the scaled step size, the function
2625
* tolerance, and the maximum number of steps.
2726
*
27+
* This overload handles non-autodiff parameters.
28+
*
2829
* @tparam F type of equation system function.
2930
* @tparam T type of initial guess vector.
31+
* @tparam Args types of additional parameters to the equation system functor
3032
*
3133
* @param[in] f Functor that evaluated the system of equations.
3234
* @param[in] x Vector of starting values.
33-
* @param[in] y Parameter vector for the equation system. The function
34-
* is overloaded to treat y as a vector of doubles or of a
35-
* a template type T.
36-
* @param[in] dat Continuous data vector for the equation system.
37-
* @param[in] dat_int Integer data vector for the equation system.
3835
* @param[in, out] msgs The print stream for warning messages.
3936
* @param[in] scaling_step_size Scaled-step stopping tolerance. If
4037
* a Newton step is smaller than the scaling step
4138
* tolerance, the code breaks, assuming the solver is no
4239
* longer making significant progress (i.e. is stuck)
4340
* @param[in] function_tolerance determines whether roots are acceptable.
4441
* @param[in] max_num_steps maximum number of function evaluations.
45-
* * @throw <code>std::invalid_argument</code> if x has size zero.
42+
* @param[in] args Additional parameters to the equation system functor.
43+
* @return theta Vector of solutions to the system of equations.
44+
* @pre f returns finite values when passed any value of x and the given args.
45+
* @throw <code>std::invalid_argument</code> if x has size zero.
46+
* @throw <code>std::invalid_argument</code> if x has non-finite elements.
47+
* @throw <code>std::invalid_argument</code> if scaled_step_size is strictly
48+
* negative.
49+
* @throw <code>std::invalid_argument</code> if function_tolerance is strictly
50+
* negative.
51+
* @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
52+
* @throw <code>std::domain_error if solver exceeds max_num_steps.
53+
*/
54+
template <typename F, typename T, typename... Args,
55+
require_eigen_vector_t<T>* = nullptr,
56+
require_all_st_arithmetic<Args...>* = nullptr>
57+
Eigen::VectorXd algebra_solver_newton_impl(const F& f, const T& x,
58+
std::ostream* const msgs,
59+
const double scaling_step_size,
60+
const double function_tolerance,
61+
const int64_t max_num_steps,
62+
const Args&... args) {
63+
const auto& x_ref = to_ref(value_of(x));
64+
65+
check_nonzero_size("algebra_solver_newton", "initial guess", x_ref);
66+
check_finite("algebra_solver_newton", "initial guess", x_ref);
67+
check_nonnegative("algebra_solver_newton", "scaling_step_size",
68+
scaling_step_size);
69+
check_nonnegative("algebra_solver_newton", "function_tolerance",
70+
function_tolerance);
71+
check_positive("algebra_solver_newton", "max_num_steps", max_num_steps);
72+
73+
return kinsol_solve(f, x_ref, scaling_step_size, function_tolerance,
74+
max_num_steps, 1, 10, KIN_LINESEARCH, msgs, args...);
75+
}
76+
77+
/**
78+
* Return the solution to the specified system of algebraic
79+
* equations given an initial guess, and parameters and data,
80+
* which get passed into the algebraic system. Use the
81+
* KINSOL solver from the SUNDIALS suite.
82+
*
83+
* The user can also specify the scaled step size, the function
84+
* tolerance, and the maximum number of steps.
85+
*
86+
* This overload handles var parameters.
87+
*
88+
* The Jacobian \(J_{xy}\) (i.e., Jacobian of unknown \(x\) w.r.t. the parameter
89+
* \(y\)) is calculated given the solution as follows. Since
90+
* \[
91+
* f(x, y) = 0,
92+
* \]
93+
* we have (\(J_{pq}\) being the Jacobian matrix \(\tfrac {dq} {dq}\))
94+
* \[
95+
* - J_{fx} J_{xy} = J_{fy},
96+
* \]
97+
* and therefore \(J_{xy}\) can be solved from system
98+
* \[
99+
* - J_{fx} J_{xy} = J_{fy}.
100+
* \]
101+
* Let \(eta\) be the adjoint with respect to \(x\); then to calculate
102+
* \[
103+
* \eta J_{xy},
104+
* \]
105+
* we solve
106+
* \[
107+
* - (\eta J_{fx}^{-1}) J_{fy}.
108+
* \]
109+
*
110+
* @tparam F type of equation system function.
111+
* @tparam T type of initial guess vector.
112+
* @tparam Args types of additional parameters to the equation system functor
113+
*
114+
* @param[in] f Functor that evaluated the system of equations.
115+
* @param[in] x Vector of starting values.
116+
* @param[in, out] msgs The print stream for warning messages.
117+
* @param[in] scaling_step_size Scaled-step stopping tolerance. If
118+
* a Newton step is smaller than the scaling step
119+
* tolerance, the code breaks, assuming the solver is no
120+
* longer making significant progress (i.e. is stuck)
121+
* @param[in] function_tolerance determines whether roots are acceptable.
122+
* @param[in] max_num_steps maximum number of function evaluations.
123+
* @param[in] args Additional parameters to the equation system functor.
124+
* @return theta Vector of solutions to the system of equations.
125+
* @pre f returns finite values when passed any value of x and the given args.
126+
* @throw <code>std::invalid_argument</code> if x has size zero.
46127
* @throw <code>std::invalid_argument</code> if x has non-finite elements.
47-
* @throw <code>std::invalid_argument</code> if y has non-finite elements.
48-
* @throw <code>std::invalid_argument</code> if dat has non-finite elements.
49-
* @throw <code>std::invalid_argument</code> if dat_int has non-finite elements.
50128
* @throw <code>std::invalid_argument</code> if scaled_step_size is strictly
51129
* negative.
52130
* @throw <code>std::invalid_argument</code> if function_tolerance is strictly
53131
* negative.
54132
* @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
55-
* @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
133+
* @throw <code>std::domain_error if solver exceeds max_num_steps.
56134
*/
57-
template <typename F, typename T, require_eigen_vector_t<T>* = nullptr>
58-
Eigen::VectorXd algebra_solver_newton(
59-
const F& f, const T& x, const Eigen::VectorXd& y,
60-
const std::vector<double>& dat, const std::vector<int>& dat_int,
61-
std::ostream* msgs = nullptr, double scaling_step_size = 1e-3,
62-
double function_tolerance = 1e-6,
63-
long int max_num_steps = 200) { // NOLINT(runtime/int)
64-
const auto& x_eval = x.eval();
65-
algebra_solver_check(x_eval, y, dat, dat_int, function_tolerance,
66-
max_num_steps);
67-
check_nonnegative("algebra_solver", "scaling_step_size", scaling_step_size);
68-
69-
check_matching_sizes("algebra_solver", "the algebraic system's output",
70-
value_of(f(x_eval, y, dat, dat_int, msgs)),
71-
"the vector of unknowns, x,", x);
72-
73-
return kinsol_solve(f, value_of(x_eval), y, dat, dat_int, 0,
74-
scaling_step_size, function_tolerance, max_num_steps);
135+
template <typename F, typename T, typename... T_Args,
136+
require_eigen_vector_t<T>* = nullptr,
137+
require_any_st_var<T_Args...>* = nullptr>
138+
Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
139+
const F& f, const T& x, std::ostream* const msgs,
140+
const double scaling_step_size, const double function_tolerance,
141+
const int64_t max_num_steps, const T_Args&... args) {
142+
const auto& x_ref = to_ref(value_of(x));
143+
auto arena_args_tuple = std::make_tuple(to_arena(args)...);
144+
auto args_vals_tuple = apply(
145+
[&](const auto&... args) {
146+
return std::make_tuple(to_ref(value_of(args))...);
147+
},
148+
arena_args_tuple);
149+
150+
check_nonzero_size("algebra_solver_newton", "initial guess", x_ref);
151+
check_finite("algebra_solver_newton", "initial guess", x_ref);
152+
check_nonnegative("algebra_solver_newton", "scaling_step_size",
153+
scaling_step_size);
154+
check_nonnegative("algebra_solver_newton", "function_tolerance",
155+
function_tolerance);
156+
check_positive("algebra_solver_newton", "max_num_steps", max_num_steps);
157+
158+
// Solve the system
159+
Eigen::VectorXd theta_dbl = apply(
160+
[&](const auto&... vals) {
161+
return kinsol_solve(f, x_ref, scaling_step_size, function_tolerance,
162+
max_num_steps, 1, 10, KIN_LINESEARCH, msgs,
163+
vals...);
164+
},
165+
args_vals_tuple);
166+
167+
auto f_wrt_x = [&](const auto& x) {
168+
return apply([&](const auto&... args) { return f(x, msgs, args...); },
169+
args_vals_tuple);
170+
};
171+
172+
Eigen::MatrixXd Jf_x;
173+
Eigen::VectorXd f_x;
174+
175+
jacobian(f_wrt_x, theta_dbl, f_x, Jf_x);
176+
177+
using ret_type = Eigen::Matrix<var, Eigen::Dynamic, -1>;
178+
arena_t<ret_type> ret = theta_dbl;
179+
auto Jf_x_T_lu_ptr
180+
= make_unsafe_chainable_ptr(Jf_x.transpose().partialPivLu()); // Lu
181+
182+
reverse_pass_callback([f, ret, arena_args_tuple, Jf_x_T_lu_ptr,
183+
msgs]() mutable {
184+
Eigen::VectorXd eta = -Jf_x_T_lu_ptr->solve(ret.adj().eval());
185+
186+
// Contract with Jacobian of f with respect to y using a nested reverse
187+
// autodiff pass.
188+
{
189+
nested_rev_autodiff rev;
190+
191+
Eigen::VectorXd ret_val = ret.val();
192+
auto x_nrad_ = apply(
193+
[&](const auto&... args) { return eval(f(ret_val, msgs, args...)); },
194+
arena_args_tuple);
195+
x_nrad_.adj() = eta;
196+
grad();
197+
}
198+
});
199+
200+
return ret_type(ret);
75201
}
76202

77203
/**
@@ -83,19 +209,15 @@ Eigen::VectorXd algebra_solver_newton(
83209
* The user can also specify the scaled step size, the function
84210
* tolerance, and the maximum number of steps.
85211
*
86-
* Overload the previous definition to handle the case where y
87-
* is a vector of parameters (var). The overload calls the
88-
* algebraic solver defined above and builds a vari object on
89-
* top, using the algebra_solver_vari class.
212+
* Signature to maintain backward compatibility, will be removed
213+
* in the future.
90214
*
91215
* @tparam F type of equation system function.
92216
* @tparam T type of initial guess vector.
93217
*
94218
* @param[in] f Functor that evaluated the system of equations.
95219
* @param[in] x Vector of starting values.
96-
* @param[in] y Parameter vector for the equation system. The function
97-
* is overloaded to treat y as a vector of doubles or of a
98-
* a template type T.
220+
* @param[in] y Parameter vector for the equation system.
99221
* @param[in] dat Continuous data vector for the equation system.
100222
* @param[in] dat_int Integer data vector for the equation system.
101223
* @param[in, out] msgs The print stream for warning messages.
@@ -119,34 +241,16 @@ Eigen::VectorXd algebra_solver_newton(
119241
* @throw <code>std::domain_error if solver exceeds max_num_steps.
120242
*/
121243
template <typename F, typename T1, typename T2,
122-
require_all_eigen_vector_t<T1, T2>* = nullptr,
123-
require_st_var<T2>* = nullptr>
244+
require_all_eigen_vector_t<T1, T2>* = nullptr>
124245
Eigen::Matrix<scalar_type_t<T2>, Eigen::Dynamic, 1> algebra_solver_newton(
125246
const F& f, const T1& x, const T2& y, const std::vector<double>& dat,
126-
const std::vector<int>& dat_int, std::ostream* msgs = nullptr,
127-
double scaling_step_size = 1e-3, double function_tolerance = 1e-6,
128-
long int max_num_steps = 200) { // NOLINT(runtime/int)
129-
130-
const auto& x_eval = x.eval();
131-
const auto& y_eval = y.eval();
132-
Eigen::VectorXd theta_dbl = algebra_solver_newton(
133-
f, x_eval, value_of(y_eval), dat, dat_int, msgs, scaling_step_size,
134-
function_tolerance, max_num_steps);
135-
136-
typedef system_functor<F, double, double, false> Fy;
137-
typedef system_functor<F, double, double, true> Fs;
138-
typedef hybrj_functor_solver<Fs, F, double, double> Fx;
139-
Fx fx(Fs(), f, value_of(x_eval), value_of(y_eval), dat, dat_int, msgs);
140-
141-
// Construct vari
142-
auto* vi0 = new algebra_solver_vari<Fy, F, scalar_type_t<T2>, Fx>(
143-
Fy(), f, value_of(x_eval), y_eval, dat, dat_int, theta_dbl, fx, msgs);
144-
Eigen::Matrix<var, Eigen::Dynamic, 1> theta(x.size());
145-
theta(0) = var(vi0->theta_[0]);
146-
for (int i = 1; i < x.size(); ++i)
147-
theta(i) = var(vi0->theta_[i]);
148-
149-
return theta;
247+
const std::vector<int>& dat_int, std::ostream* const msgs = nullptr,
248+
const double scaling_step_size = 1e-3,
249+
const double function_tolerance = 1e-6,
250+
const long int max_num_steps = 200) { // NOLINT(runtime/int)
251+
return algebra_solver_newton_impl(algebra_solver_adapter<F>(f), x, msgs,
252+
scaling_step_size, function_tolerance,
253+
max_num_steps, y, dat, dat_int);
150254
}
151255

152256
} // namespace math

0 commit comments

Comments
 (0)