Skip to content

Commit faba87c

Browse files
authored
Merge pull request #2563 from stan-dev/revert-2421-feature/issue-2401-alg-solver-adjoint
Revert "Feature/issue 2401 alg solver adjoint"
2 parents 7dd7b31 + bfc7987 commit faba87c

13 files changed

+808
-1191
lines changed

stan/math/prim/functor/algebra_solver_adapter.hpp

-26
This file was deleted.

stan/math/rev/core/chainable_object.hpp

-55
Original file line numberDiff line numberDiff line change
@@ -62,61 +62,6 @@ auto make_chainable_ptr(T&& obj) {
6262
return &ptr->get();
6363
}
6464

65-
/**
66-
* `unsafe_chainable_object` holds another object and is useful for connecting
67-
* the lifetime of a specific object to the chainable stack. This class
68-
* differs from `chainable_object` in that this class does not evaluate
69-
* expressions.
70-
*
71-
* `unsafe_chainable_object` objects should only be allocated with `new`.
72-
* `unsafe_chainable_object` objects allocated on the stack will result
73-
* in a double free (`obj_` will get destructed once when the
74-
* `unsafe_chainable_object` leaves scope and once when the chainable
75-
* stack memory is recovered).
76-
*
77-
* @tparam T type of object to hold
78-
*/
79-
template <typename T>
80-
class unsafe_chainable_object : public chainable_alloc {
81-
private:
82-
std::decay_t<T> obj_;
83-
84-
public:
85-
/**
86-
* Construct chainable object from another object
87-
*
88-
* @tparam S type of object to hold (must have the same plain type as `T`)
89-
*/
90-
template <typename S,
91-
require_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
92-
explicit unsafe_chainable_object(S&& obj) : obj_(std::forward<S>(obj)) {}
93-
94-
/**
95-
* Return a reference to the underlying object
96-
*
97-
* @return reference to underlying object
98-
*/
99-
inline auto& get() noexcept { return obj_; }
100-
inline const auto& get() const noexcept { return obj_; }
101-
};
102-
103-
/**
104-
* Store the given object in a `chainable_object` so it is destructed
105-
* only when the chainable stack memory is recovered and return
106-
* a pointer to the underlying object. This function
107-
* differs from `make_chainable_object` in that this class does not evaluate
108-
* expressions.
109-
*
110-
* @tparam T type of object to hold
111-
* @param obj object to hold
112-
* @return pointer to object held in `chainable_object`
113-
*/
114-
template <typename T>
115-
auto make_unsafe_chainable_ptr(T&& obj) {
116-
auto ptr = new unsafe_chainable_object<T>(std::forward<T>(obj));
117-
return &ptr->get();
118-
}
119-
12065
} // namespace math
12166
} // namespace stan
12267
#endif

stan/math/rev/functor/algebra_solver_newton.hpp

+63-167
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
#include <stan/math/rev/core.hpp>
55
#include <stan/math/rev/functor/algebra_system.hpp>
6+
#include <stan/math/rev/functor/algebra_solver_powell.hpp>
67
#include <stan/math/rev/functor/kinsol_solve.hpp>
78
#include <stan/math/prim/err.hpp>
9+
#include <stan/math/prim/fun/mdivide_left.hpp>
810
#include <stan/math/prim/fun/value_of.hpp>
9-
#include <stan/math/prim/functor/algebra_solver_adapter.hpp>
1011
#include <unsupported/Eigen/NonLinearOptimization>
1112
#include <iostream>
1213
#include <string>
@@ -24,180 +25,53 @@ namespace math {
2425
* The user can also specify the scaled step size, the function
2526
* tolerance, and the maximum number of steps.
2627
*
27-
* This overload handles non-autodiff parameters.
28-
*
29-
* @tparam F type of equation system function.
30-
* @tparam T type of initial guess vector.
31-
* @tparam Args types of additional parameters to the equation system functor
32-
*
33-
* @param[in] f Functor that evaluated the system of equations.
34-
* @param[in] x Vector of starting values.
35-
* @param[in, out] msgs The print stream for warning messages.
36-
* @param[in] scaling_step_size Scaled-step stopping tolerance. If
37-
* a Newton step is smaller than the scaling step
38-
* tolerance, the code breaks, assuming the solver is no
39-
* longer making significant progress (i.e. is stuck)
40-
* @param[in] function_tolerance determines whether roots are acceptable.
41-
* @param[in] max_num_steps maximum number of function evaluations.
42-
* @param[in] args Additional parameters to the equation system functor.
43-
* @return theta Vector of solutions to the system of equations.
44-
* @pre f returns finite values when passed any value of x and the given args.
45-
* @throw <code>std::invalid_argument</code> if x has size zero.
46-
* @throw <code>std::invalid_argument</code> if x has non-finite elements.
47-
* @throw <code>std::invalid_argument</code> if scaled_step_size is strictly
48-
* negative.
49-
* @throw <code>std::invalid_argument</code> if function_tolerance is strictly
50-
* negative.
51-
* @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
52-
* @throw <code>std::domain_error if solver exceeds max_num_steps.
53-
*/
54-
template <typename F, typename T, typename... Args,
55-
require_eigen_vector_t<T>* = nullptr,
56-
require_all_st_arithmetic<Args...>* = nullptr>
57-
Eigen::VectorXd algebra_solver_newton_impl(const F& f, const T& x,
58-
std::ostream* const msgs,
59-
const double scaling_step_size,
60-
const double function_tolerance,
61-
const int64_t max_num_steps,
62-
const Args&... args) {
63-
const auto& x_ref = to_ref(value_of(x));
64-
65-
check_nonzero_size("algebra_solver_newton", "initial guess", x_ref);
66-
check_finite("algebra_solver_newton", "initial guess", x_ref);
67-
check_nonnegative("algebra_solver_newton", "scaling_step_size",
68-
scaling_step_size);
69-
check_nonnegative("algebra_solver_newton", "function_tolerance",
70-
function_tolerance);
71-
check_positive("algebra_solver_newton", "max_num_steps", max_num_steps);
72-
73-
return kinsol_solve(f, x_ref, scaling_step_size, function_tolerance,
74-
max_num_steps, 1, 10, KIN_LINESEARCH, msgs, args...);
75-
}
76-
77-
/**
78-
* Return the solution to the specified system of algebraic
79-
* equations given an initial guess, and parameters and data,
80-
* which get passed into the algebraic system. Use the
81-
* KINSOL solver from the SUNDIALS suite.
82-
*
83-
* The user can also specify the scaled step size, the function
84-
* tolerance, and the maximum number of steps.
85-
*
86-
* This overload handles var parameters.
87-
*
88-
* The Jacobian \(J_{xy}\) (i.e., Jacobian of unknown \(x\) w.r.t. the parameter
89-
* \(y\)) is calculated given the solution as follows. Since
90-
* \[
91-
* f(x, y) = 0,
92-
* \]
93-
* we have (\(J_{pq}\) being the Jacobian matrix \(\tfrac{dp}{dq}\))
94-
* \[
95-
* - J_{fx} J_{xy} = J_{fy},
96-
* \]
97-
* and therefore \(J_{xy}\) can be solved from system
98-
* \[
99-
* - J_{fx} J_{xy} = J_{fy}.
100-
* \]
101-
* Let \(\eta\) be the adjoint with respect to \(x\); then to calculate
102-
* \[
103-
* \eta J_{xy},
104-
* \]
105-
* we solve
106-
* \[
107-
* - (\eta J_{fx}^{-1}) J_{fy}.
108-
* \]
109-
*
11028
* @tparam F type of equation system function.
11129
* @tparam T type of initial guess vector.
112-
* @tparam Args types of additional parameters to the equation system functor
11330
*
11431
* @param[in] f Functor that evaluated the system of equations.
11532
* @param[in] x Vector of starting values.
33+
* @param[in] y Parameter vector for the equation system. The function
34+
* is overloaded to treat y as a vector of doubles or of a
35+
* template type T.
36+
* @param[in] dat Continuous data vector for the equation system.
37+
* @param[in] dat_int Integer data vector for the equation system.
11638
* @param[in, out] msgs The print stream for warning messages.
11739
* @param[in] scaling_step_size Scaled-step stopping tolerance. If
11840
* a Newton step is smaller than the scaling step
11941
* tolerance, the code breaks, assuming the solver is no
12042
* longer making significant progress (i.e. is stuck)
12143
* @param[in] function_tolerance determines whether roots are acceptable.
12244
* @param[in] max_num_steps maximum number of function evaluations.
123-
* @param[in] args Additional parameters to the equation system functor.
124-
* @return theta Vector of solutions to the system of equations.
125-
* @pre f returns finite values when passed any value of x and the given args.
126-
* @throw <code>std::invalid_argument</code> if x has size zero.
45+
* @throw <code>std::invalid_argument</code> if x has size zero.
12746
* @throw <code>std::invalid_argument</code> if x has non-finite elements.
47+
* @throw <code>std::invalid_argument</code> if y has non-finite elements.
48+
* @throw <code>std::invalid_argument</code> if dat has non-finite elements.
49+
* @throw <code>std::invalid_argument</code> if dat_int has non-finite elements.
12850
* @throw <code>std::invalid_argument</code> if scaled_step_size is strictly
12951
* negative.
13052
* @throw <code>std::invalid_argument</code> if function_tolerance is strictly
13153
* negative.
13254
* @throw <code>std::invalid_argument</code> if max_num_steps is not positive.
133-
* @throw <code>std::domain_error if solver exceeds max_num_steps.
55+
* @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
13456
*/
135-
template <typename F, typename T, typename... T_Args,
136-
require_eigen_vector_t<T>* = nullptr,
137-
require_any_st_var<T_Args...>* = nullptr>
138-
Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
139-
const F& f, const T& x, std::ostream* const msgs,
140-
const double scaling_step_size, const double function_tolerance,
141-
const int64_t max_num_steps, const T_Args&... args) {
142-
const auto& x_ref = to_ref(value_of(x));
143-
auto arena_args_tuple = std::make_tuple(to_arena(args)...);
144-
auto args_vals_tuple = apply(
145-
[&](const auto&... args) {
146-
return std::make_tuple(to_ref(value_of(args))...);
147-
},
148-
arena_args_tuple);
149-
150-
check_nonzero_size("algebra_solver_newton", "initial guess", x_ref);
151-
check_finite("algebra_solver_newton", "initial guess", x_ref);
152-
check_nonnegative("algebra_solver_newton", "scaling_step_size",
153-
scaling_step_size);
154-
check_nonnegative("algebra_solver_newton", "function_tolerance",
155-
function_tolerance);
156-
check_positive("algebra_solver_newton", "max_num_steps", max_num_steps);
157-
158-
// Solve the system
159-
Eigen::VectorXd theta_dbl = apply(
160-
[&](const auto&... vals) {
161-
return kinsol_solve(f, x_ref, scaling_step_size, function_tolerance,
162-
max_num_steps, 1, 10, KIN_LINESEARCH, msgs,
163-
vals...);
164-
},
165-
args_vals_tuple);
166-
167-
auto f_wrt_x = [&](const auto& x) {
168-
return apply([&](const auto&... args) { return f(x, msgs, args...); },
169-
args_vals_tuple);
170-
};
171-
172-
Eigen::MatrixXd Jf_x;
173-
Eigen::VectorXd f_x;
174-
175-
jacobian(f_wrt_x, theta_dbl, f_x, Jf_x);
176-
177-
using ret_type = Eigen::Matrix<var, Eigen::Dynamic, -1>;
178-
arena_t<ret_type> ret = theta_dbl;
179-
auto Jf_x_T_lu_ptr
180-
= make_unsafe_chainable_ptr(Jf_x.transpose().partialPivLu()); // Lu
181-
182-
reverse_pass_callback([f, ret, arena_args_tuple, Jf_x_T_lu_ptr,
183-
msgs]() mutable {
184-
Eigen::VectorXd eta = -Jf_x_T_lu_ptr->solve(ret.adj().eval());
185-
186-
// Contract with Jacobian of f with respect to y using a nested reverse
187-
// autodiff pass.
188-
{
189-
nested_rev_autodiff rev;
190-
191-
Eigen::VectorXd ret_val = ret.val();
192-
auto x_nrad_ = apply(
193-
[&](const auto&... args) { return eval(f(ret_val, msgs, args...)); },
194-
arena_args_tuple);
195-
x_nrad_.adj() = eta;
196-
grad();
197-
}
198-
});
199-
200-
return ret_type(ret);
57+
template <typename F, typename T, require_eigen_vector_t<T>* = nullptr>
58+
Eigen::VectorXd algebra_solver_newton(
59+
const F& f, const T& x, const Eigen::VectorXd& y,
60+
const std::vector<double>& dat, const std::vector<int>& dat_int,
61+
std::ostream* msgs = nullptr, double scaling_step_size = 1e-3,
62+
double function_tolerance = 1e-6,
63+
long int max_num_steps = 200) { // NOLINT(runtime/int)
64+
const auto& x_eval = x.eval();
65+
algebra_solver_check(x_eval, y, dat, dat_int, function_tolerance,
66+
max_num_steps);
67+
check_nonnegative("algebra_solver", "scaling_step_size", scaling_step_size);
68+
69+
check_matching_sizes("algebra_solver", "the algebraic system's output",
70+
value_of(f(x_eval, y, dat, dat_int, msgs)),
71+
"the vector of unknowns, x,", x);
72+
73+
return kinsol_solve(f, value_of(x_eval), y, dat, dat_int, 0,
74+
scaling_step_size, function_tolerance, max_num_steps);
20175
}
20276

20377
/**
@@ -209,15 +83,19 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
20983
* The user can also specify the scaled step size, the function
21084
* tolerance, and the maximum number of steps.
21185
*
212-
* Signature to maintain backward compatibility, will be removed
213-
* in the future.
86+
* Overload the previous definition to handle the case where y
87+
* is a vector of parameters (var). The overload calls the
88+
* algebraic solver defined above and builds a vari object on
89+
* top, using the algebra_solver_vari class.
21490
*
21591
* @tparam F type of equation system function.
21692
* @tparam T type of initial guess vector.
21793
*
21894
* @param[in] f Functor that evaluated the system of equations.
21995
* @param[in] x Vector of starting values.
220-
* @param[in] y Parameter vector for the equation system.
96+
* @param[in] y Parameter vector for the equation system. The function
97+
* is overloaded to treat y as a vector of doubles or of a
98+
* template type T.
22199
* @param[in] dat Continuous data vector for the equation system.
222100
* @param[in] dat_int Integer data vector for the equation system.
223101
* @param[in, out] msgs The print stream for warning messages.
@@ -241,16 +119,34 @@ Eigen::Matrix<var, Eigen::Dynamic, 1> algebra_solver_newton_impl(
241119
* @throw <code>std::domain_error</code> if solver exceeds max_num_steps.
242120
*/
243121
template <typename F, typename T1, typename T2,
244-
require_all_eigen_vector_t<T1, T2>* = nullptr>
122+
require_all_eigen_vector_t<T1, T2>* = nullptr,
123+
require_st_var<T2>* = nullptr>
245124
Eigen::Matrix<scalar_type_t<T2>, Eigen::Dynamic, 1> algebra_solver_newton(
246125
const F& f, const T1& x, const T2& y, const std::vector<double>& dat,
247-
const std::vector<int>& dat_int, std::ostream* const msgs = nullptr,
248-
const double scaling_step_size = 1e-3,
249-
const double function_tolerance = 1e-6,
250-
const long int max_num_steps = 200) { // NOLINT(runtime/int)
251-
return algebra_solver_newton_impl(algebra_solver_adapter<F>(f), x, msgs,
252-
scaling_step_size, function_tolerance,
253-
max_num_steps, y, dat, dat_int);
126+
const std::vector<int>& dat_int, std::ostream* msgs = nullptr,
127+
double scaling_step_size = 1e-3, double function_tolerance = 1e-6,
128+
long int max_num_steps = 200) { // NOLINT(runtime/int)
129+
130+
const auto& x_eval = x.eval();
131+
const auto& y_eval = y.eval();
132+
Eigen::VectorXd theta_dbl = algebra_solver_newton(
133+
f, x_eval, value_of(y_eval), dat, dat_int, msgs, scaling_step_size,
134+
function_tolerance, max_num_steps);
135+
136+
typedef system_functor<F, double, double, false> Fy;
137+
typedef system_functor<F, double, double, true> Fs;
138+
typedef hybrj_functor_solver<Fs, F, double, double> Fx;
139+
Fx fx(Fs(), f, value_of(x_eval), value_of(y_eval), dat, dat_int, msgs);
140+
141+
// Construct vari
142+
auto* vi0 = new algebra_solver_vari<Fy, F, scalar_type_t<T2>, Fx>(
143+
Fy(), f, value_of(x_eval), y_eval, dat, dat_int, theta_dbl, fx, msgs);
144+
Eigen::Matrix<var, Eigen::Dynamic, 1> theta(x.size());
145+
theta(0) = var(vi0->theta_[0]);
146+
for (int i = 1; i < x.size(); ++i)
147+
theta(i) = var(vi0->theta_[i]);
148+
149+
return theta;
254150
}
255151

256152
} // namespace math

0 commit comments

Comments
 (0)