1212#include < stan/math/prim/functor/iter_tuple_nested.hpp>
1313#include < unsupported/Eigen/MatrixFunctions>
1414#include < cmath>
15+ #include < optional>
1516
1617/* *
1718 * @file
@@ -26,7 +27,7 @@ namespace math {
2627/* *
2728 * Options for the laplace sampler
2829 */
29- struct laplace_options {
30+ struct laplace_options_base {
3031 /* Size of the blocks in block diagonal hessian*/
3132 int hessian_block_size{1 };
3233 /* *
@@ -46,6 +47,20 @@ struct laplace_options {
4647 int max_num_steps{100 };
4748};
4849
50+ template <bool HasInitTheta>
51+ struct laplace_options ;
52+
53+ template <>
54+ struct laplace_options <false > : public laplace_options_base {};
55+
56+ template <>
57+ struct laplace_options <true > : public laplace_options_base {
58+ /* Value for user supplied initial theta */
59+ Eigen::VectorXd theta_0{0 };
60+ };
61+
62+ using laplace_options_default = laplace_options<false >;
63+ using laplace_options_user_supplied = laplace_options<true >;
4964namespace internal {
5065
5166template <typename Covar, typename ThetaVec, typename WR, typename L_t,
@@ -448,37 +463,46 @@ inline STAN_COLD_PATH void throw_nan(NameStr&& name_str, ParamStr&& param_str,
448463 *
449464 */
450465template <typename LLFun, typename LLTupleArgs, typename CovarFun,
451- typename ThetaVec, typename CovarArgs,
452- require_t <is_all_arithmetic_scalar<ThetaVec, CovarArgs>>* = nullptr ,
453- require_eigen_vector_t <ThetaVec>* = nullptr >
454- inline auto laplace_marginal_density_est (LLFun&& ll_fun, LLTupleArgs&& ll_args,
455- ThetaVec&& theta_0,
456- CovarFun&& covariance_function,
457- CovarArgs&& covar_args,
458- const laplace_options& options,
459- std::ostream* msgs) {
466+ typename CovarArgs, bool InitTheta,
467+ require_t <is_all_arithmetic_scalar<CovarArgs>>* = nullptr >
468+ inline auto laplace_marginal_density_est (
469+ LLFun&& ll_fun, LLTupleArgs&& ll_args, CovarFun&& covariance_function,
470+ CovarArgs&& covar_args, const laplace_options<InitTheta>& options,
471+ std::ostream* msgs) {
460472 using Eigen::MatrixXd;
461473 using Eigen::SparseMatrix;
462474 using Eigen::VectorXd;
463- check_nonzero_size (" laplace_marginal" , " initial guess" , theta_0);
464- check_finite (" laplace_marginal" , " initial guess" , theta_0);
475+ if constexpr (InitTheta) {
476+ check_nonzero_size (" laplace_marginal" , " initial guess" , options.theta_0 );
477+ check_finite (" laplace_marginal" , " initial guess" , options.theta_0 );
478+ }
465479 check_nonnegative (" laplace_marginal" , " tolerance" , options.tolerance );
466480 check_positive (" laplace_marginal" , " max_num_steps" , options.max_num_steps );
467481 check_positive (" laplace_marginal" , " hessian_block_size" ,
468482 options.hessian_block_size );
469483 check_nonnegative (" laplace_marginal" , " max_steps_line_search" ,
470484 options.max_steps_line_search );
471- if (unlikely (theta_0.size () % options.hessian_block_size != 0 )) {
485+
486+ Eigen::MatrixXd covariance = stan::math::apply (
487+ [msgs, &covariance_function](auto &&... args) {
488+ return covariance_function (args..., msgs);
489+ },
490+ covar_args);
491+ check_square (" laplace_marginal" , " covariance" , covariance);
492+
493+ const Eigen::Index theta_size = covariance.rows ();
494+
495+ if (unlikely (theta_size % options.hessian_block_size != 0 )) {
472496 [&]() STAN_COLD_PATH {
473497 std::stringstream msg;
474- msg << " laplace_marginal_density: The hessian size (" << theta_0. size ()
475- << " , " << theta_0. size ()
498+ msg << " laplace_marginal_density: The hessian size (" << theta_size
499+ << " , " << theta_size
476500 << " ) is not divisible by the hessian block size ("
477501 << options.hessian_block_size
478502 << " )"
479503 " . Try a hessian block size such as [1, " ;
480504 for (int i = 2 ; i < 12 ; ++i) {
481- if (theta_0. size () % i == 0 ) {
505+ if (theta_size % i == 0 ) {
482506 msg << i << " , " ;
483507 }
484508 }
@@ -488,19 +512,20 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
488512 throw std::domain_error (msg.str ());
489513 }();
490514 }
491- Eigen::MatrixXd covariance = stan::math::apply (
492- [msgs, &covariance_function](auto &&... args) {
493- return covariance_function (args..., msgs);
494- },
495- covar_args);
515+
496516 auto throw_overstep = [](const auto max_num_steps) STAN_COLD_PATH {
497517 throw std::domain_error (
498518 std::string (" laplace_marginal_density: max number of iterations: " )
499519 + std::to_string (max_num_steps) + " exceeded." );
500520 };
501521 auto ll_args_vals = value_of (ll_args);
502- const Eigen::Index theta_size = theta_0.size ();
503- Eigen::VectorXd theta = std::forward<ThetaVec>(theta_0);
522+ Eigen::VectorXd theta = [theta_size, &options]() {
523+ if constexpr (InitTheta) {
524+ return options.theta_0 ;
525+ } else {
526+ return Eigen::VectorXd::Zero (theta_size);
527+ }
528+ }();
504529 double objective_old = std::numeric_limits<double >::lowest ();
505530 double objective_new = std::numeric_limits<double >::lowest () + 1 ;
506531 Eigen::VectorXd a_prev = Eigen::VectorXd::Zero (theta_size);
@@ -572,7 +597,7 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
572597 }
573598 }
574599 } else {
575- Eigen::SparseMatrix<double > W_r (theta. rows (), theta. rows () );
600+ Eigen::SparseMatrix<double > W_r (theta_size, theta_size );
576601 Eigen::Index block_size = options.hessian_block_size ;
577602 W_r.reserve (Eigen::VectorXi::Constant (W_r.cols (), block_size));
578603 const Eigen::Index n_block = W_r.cols () / block_size;
@@ -768,20 +793,16 @@ inline auto laplace_marginal_density_est(LLFun&& ll_fun, LLTupleArgs&& ll_args,
768793 * \msg_arg
769794 * @return the log maginal density, p(y | phi)
770795 */
771- template <typename LLFun, typename LLTupleArgs, typename CovarFun,
772- typename ThetaVec, typename CovarArgs,
773- require_t <is_all_arithmetic_scalar<ThetaVec, CovarArgs,
774- LLTupleArgs>>* = nullptr ,
775- require_eigen_vector_t <ThetaVec>* = nullptr >
776- inline double laplace_marginal_density (LLFun&& ll_fun, LLTupleArgs&& ll_args,
777- ThetaVec&& theta_0,
778- CovarFun&& covariance_function,
779- CovarArgs&& covar_args,
780- const laplace_options& options,
781- std::ostream* msgs) {
796+ template <
797+ typename LLFun, typename LLTupleArgs, typename CovarFun, typename CovarArgs,
798+ bool InitTheta,
799+ require_t <is_all_arithmetic_scalar<CovarArgs, LLTupleArgs>>* = nullptr >
800+ inline double laplace_marginal_density (
801+ LLFun&& ll_fun, LLTupleArgs&& ll_args, CovarFun&& covariance_function,
802+ CovarArgs&& covar_args, const laplace_options<InitTheta>& options,
803+ std::ostream* msgs) {
782804 return internal::laplace_marginal_density_est (
783805 std::forward<LLFun>(ll_fun), std::forward<LLTupleArgs>(ll_args),
784- std::forward<ThetaVec>(theta_0),
785806 std::forward<CovarFun>(covariance_function),
786807 std::forward<CovarArgs>(covar_args), options, msgs)
787808 .lmd ;
@@ -1014,16 +1035,13 @@ inline void reverse_pass_collect_adjoints(var ret, Output&& output,
10141035 * \msg_arg
10151036 * @return the log maginal density, p(y | phi)
10161037 */
1017- template <
1018- typename LLFun, typename LLTupleArgs, typename CovarFun, typename ThetaVec,
1019- typename CovarArgs,
1020- require_t <is_any_var_scalar<ThetaVec, LLTupleArgs, CovarArgs>>* = nullptr ,
1021- require_eigen_vector_t <ThetaVec>* = nullptr >
1038+ template <typename LLFun, typename LLTupleArgs, typename CovarFun,
1039+ typename CovarArgs, bool InitTheta,
1040+ require_t <is_any_var_scalar<LLTupleArgs, CovarArgs>>* = nullptr >
10221041inline auto laplace_marginal_density (const LLFun& ll_fun, LLTupleArgs&& ll_args,
1023- ThetaVec&& theta_0,
10241042 CovarFun&& covariance_function,
10251043 CovarArgs&& covar_args,
1026- const laplace_options& options,
1044+ const laplace_options<InitTheta> & options,
10271045 std::ostream* msgs) {
10281046 auto covar_args_refs = to_ref (std::forward<CovarArgs>(covar_args));
10291047 auto ll_args_refs = to_ref (std::forward<LLTupleArgs>(ll_args));
@@ -1034,13 +1052,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10341052 double lmd = 0.0 ;
10351053 {
10361054 nested_rev_autodiff nested;
1037- // Solver 1, 2
1038- arena_t <Eigen::MatrixXd> R (theta_0.size (), theta_0.size ());
1039- // Solver 3
1040- arena_t <Eigen::MatrixXd> LU_solve_covariance;
1041- // Solver 1, 2, 3
1042- arena_t <promote_scalar_t <double , plain_type_t <std::decay_t <ThetaVec>>>> s2 (
1043- theta_0.size ());
1055+
10441056 // Make one hard copy here
10451057 using laplace_likelihood::internal::conditional_copy_and_promote;
10461058 using laplace_likelihood::internal::COPY_TYPE;
@@ -1049,8 +1061,16 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
10491061 ll_args_refs);
10501062
10511063 auto md_est = internal::laplace_marginal_density_est (
1052- ll_fun, ll_args_copy, value_of (theta_0), covariance_function,
1053- value_of (covar_args_refs), options, msgs);
1064+ ll_fun, ll_args_copy, covariance_function, value_of (covar_args_refs),
1065+ options, msgs);
1066+
1067+ // Solver 1, 2
1068+ arena_t <Eigen::MatrixXd> R (md_est.theta .size (), md_est.theta .size ());
1069+ // Solver 3
1070+ arena_t <Eigen::MatrixXd> LU_solve_covariance;
1071+ // Solver 1, 2, 3
1072+ arena_t <Eigen::VectorXd> s2 (md_est.theta .size ());
1073+
10541074 // Return references to var types
10551075 auto ll_args_filter = internal::filter_var_scalar_types (ll_args_copy);
10561076 stan::math::for_each (
0 commit comments