diff --git a/stan/math/fwd/fun/hypergeometric_1F0.hpp b/stan/math/fwd/fun/hypergeometric_1F0.hpp index b304d5c8381..5337f411f00 100644 --- a/stan/math/fwd/fun/hypergeometric_1F0.hpp +++ b/stan/math/fwd/fun/hypergeometric_1F0.hpp @@ -31,10 +31,10 @@ namespace math { template , require_all_stan_scalar_t* = nullptr, require_any_fvar_t* = nullptr> -FvarT hypergeometric_1f0(const Ta& a, const Tz& z) { +FvarT hypergeometric_1F0(const Ta& a, const Tz& z) { partials_type_t a_val = value_of(a); partials_type_t z_val = value_of(z); - FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0); + FvarT rtn = FvarT(hypergeometric_1F0(a_val, z_val), 0.0); if (!is_constant_all::value) { rtn.d_ += forward_as(a).d() * -rtn.val() * log1m(z_val); } diff --git a/stan/math/fwd/fun/hypergeometric_2F1.hpp b/stan/math/fwd/fun/hypergeometric_2F1.hpp index 9e450cda01f..afaf7cef896 100644 --- a/stan/math/fwd/fun/hypergeometric_2F1.hpp +++ b/stan/math/fwd/fun/hypergeometric_2F1.hpp @@ -30,11 +30,11 @@ namespace math { template * = nullptr, require_any_fvar_t* = nullptr> -inline return_type_t hypergeometric_2F1(const Ta1& a1, +inline return_type_t hypergeometric_2F1(const Ta1& a1, const Ta2& a2, const Tb& b, const Tz& z) { - using fvar_t = return_type_t; + using fvar_t = return_type_t; auto a1_val = value_of(a1); auto a2_val = value_of(a2); diff --git a/stan/math/fwd/fun/hypergeometric_pFq.hpp b/stan/math/fwd/fun/hypergeometric_pFq.hpp index 3165878ffcf..85b8cc623aa 100644 --- a/stan/math/fwd/fun/hypergeometric_pFq.hpp +++ b/stan/math/fwd/fun/hypergeometric_pFq.hpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include namespace stan { namespace math { @@ -30,33 +32,27 @@ template ::value, require_all_vector_t* = nullptr, require_fvar_t* = nullptr> -inline FvarT hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) { - using PartialsT = partials_type_t; - using ARefT = ref_type_t; - using BRefT = ref_type_t; - - ARefT a_ref = a; - BRefT b_ref = b; +inline FvarT hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) { + auto&& a_ref = to_ref(as_column_vector_or_scalar(a)); + auto&& b_ref = to_ref(as_column_vector_or_scalar(b)); auto&& a_val = value_of(a_ref); auto&& b_val = value_of(b_ref); auto&& z_val = value_of(z); - PartialsT pfq_val = hypergeometric_pFq(a_val, b_val, z_val); + + partials_type_t pfq_val = hypergeometric_pFq(a_val, b_val, z_val); auto grad_tuple = grad_pFq(pfq_val, a_val, b_val, z_val); FvarT rtn = FvarT(pfq_val, 0.0); - if (grad_a) { - rtn.d_ += dot_product(forward_as>(a_ref).d(), - std::get<0>(grad_tuple)); + if constexpr (grad_a) { + rtn.d_ += dot_product(a_ref.d(), std::get<0>(grad_tuple)); } - if (grad_b) { - rtn.d_ += dot_product(forward_as>(b_ref).d(), - std::get<1>(grad_tuple)); + if constexpr (grad_b) { + rtn.d_ += dot_product(b_ref.d(), std::get<1>(grad_tuple)); } - if (grad_z) { - rtn.d_ += forward_as>(z).d_ - * std::get<2>(grad_tuple); + if constexpr (grad_z) { + rtn.d_ += z.d_ * std::get<2>(grad_tuple); } return rtn; diff --git a/stan/math/prim/fun/hypergeometric_1F0.hpp b/stan/math/prim/fun/hypergeometric_1F0.hpp index 219cf4a0eb1..2833e503887 100644 --- a/stan/math/prim/fun/hypergeometric_1F0.hpp +++ b/stan/math/prim/fun/hypergeometric_1F0.hpp @@ -28,9 +28,8 @@ namespace math { * @return Hypergeometric 1F0 function */ template * = nullptr> -return_type_t hypergeometric_1f0(const Ta& a, const Tz& z) { - constexpr const char* function = "hypergeometric_1f0"; - check_less("hypergeometric_1f0", "abs(z)", std::fabs(z), 1.0); +return_type_t hypergeometric_1F0(const Ta& a, const Tz& z) { + check_less("hypergeometric_1F0", "abs(z)", std::fabs(z), 1.0); return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>()); } diff --git a/stan/math/prim/fun/hypergeometric_2F1.hpp b/stan/math/prim/fun/hypergeometric_2F1.hpp index 2dbae6bc418..ae327e033f1 100644 --- a/stan/math/prim/fun/hypergeometric_2F1.hpp +++ b/stan/math/prim/fun/hypergeometric_2F1.hpp @@ -43,7 +43,7 @@ namespace internal { * @return Gauss hypergeometric function */ template >, + typename RtnT = boost::optional>, require_all_arithmetic_t* = nullptr> inline RtnT hyper_2F1_special_cases(const Ta1& a1, const Ta2& a2, const Tb& b, const Tz& z) { @@ -148,10 +148,10 @@ inline RtnT hyper_2F1_special_cases(const Ta1& a1, const Ta2& a2, const Tb& b, * @return Gauss hypergeometric function */ template , + typename ScalarT = return_type_t, typename OptT = boost::optional, require_all_arithmetic_t* = nullptr> -inline return_type_t hypergeometric_2F1(const Ta1& a1, +inline return_type_t hypergeometric_2F1(const Ta1& a1, const Ta2& a2, const Tb& b, const Tz& z) { diff --git a/stan/math/prim/fun/hypergeometric_3F2.hpp b/stan/math/prim/fun/hypergeometric_3F2.hpp index 68479d1cbcc..0cff50e8daa 100644 --- a/stan/math/prim/fun/hypergeometric_3F2.hpp +++ b/stan/math/prim/fun/hypergeometric_3F2.hpp @@ -114,13 +114,20 @@ template * = nullptr, require_stan_scalar_t* = nullptr> inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) { - check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z); + check_size_match("hypergeometric_3F2", "a", a.size(), "3", 3); + check_size_match("hypergeometric_3F2", "b", b.size(), "2", 2); + + auto a_ref = to_vector(a); + auto b_ref = to_vector(b); + + check_3F2_converges("hypergeometric_3F2", a_ref[0], a_ref[1], a_ref[2], + b_ref[0], b_ref[1], z); // Boost's pFq throws convergence errors in some cases, fallback to naive // infinite-sum approach (tests pass for these) - if (z == 1.0 && (sum(b) - sum(a)) < 0.0) { - return internal::hypergeometric_3F2_infsum(a, b, z); + if (z == 1.0 && (sum(b_ref) - sum(a_ref)) < 0.0) { + return internal::hypergeometric_3F2_infsum(a_ref, b_ref, z); } - return hypergeometric_pFq(to_vector(a), to_vector(b), z); + return hypergeometric_pFq(a_ref, b_ref, z); } /** diff --git a/stan/math/prim/fun/hypergeometric_pFq.hpp b/stan/math/prim/fun/hypergeometric_pFq.hpp index 388a7113f5d..c2b30610314 100644 --- a/stan/math/prim/fun/hypergeometric_pFq.hpp +++ b/stan/math/prim/fun/hypergeometric_pFq.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace stan { @@ -14,10 +15,6 @@ namespace math { * input arguments: * \f$_pF_q(a_1,...,a_p;b_1,...,b_q;z)\f$ * - * This function is not intended to be exposed to end users, only - * used for p & q values that are stable with the grad_pFq - * implementation. - * * See 'grad_pFq.hpp' for the derivatives wrt each parameter * * @param[in] a Vector of 'a' arguments to function @@ -26,7 +23,7 @@ namespace math { * @return Generalized hypergeometric function */ template * = nullptr, + require_all_vector_st* = nullptr, require_arithmetic_t* = nullptr> return_type_t hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) { @@ -47,8 +44,9 @@ return_type_t hypergeometric_pFq(const Ta& a, const Tb& b, std::stringstream msg; msg << "hypergeometric function pFq does not meet convergence " << "conditions with given arguments. " - << "a: " << a_ref << ", b: " << b_ref << ", " - << ", z: " << z; + << "a: " << to_row_vector(a_ref) << ", " + << "b: " << to_row_vector(b_ref) << ", " + << "z: " << z; throw std::domain_error(msg.str()); } diff --git a/stan/math/rev/fun/hypergeometric_1F0.hpp b/stan/math/rev/fun/hypergeometric_1F0.hpp index 10774fdc915..1817afe61a4 100644 --- a/stan/math/rev/fun/hypergeometric_1F0.hpp +++ b/stan/math/rev/fun/hypergeometric_1F0.hpp @@ -31,10 +31,10 @@ namespace math { template * = nullptr, require_any_var_t* = nullptr> -var hypergeometric_1f0(const Ta& a, const Tz& z) { +var hypergeometric_1F0(const Ta& a, const Tz& z) { double a_val = value_of(a); double z_val = value_of(z); - double rtn = hypergeometric_1f0(a_val, z_val); + double rtn = hypergeometric_1F0(a_val, z_val); return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable { if (!is_constant_all::value) { forward_as(a).adj() += vi.adj() * -rtn * log1m(z_val); diff --git a/stan/math/rev/fun/hypergeometric_2F1.hpp b/stan/math/rev/fun/hypergeometric_2F1.hpp index 17b72518656..f06b99604c7 100644 --- a/stan/math/rev/fun/hypergeometric_2F1.hpp +++ b/stan/math/rev/fun/hypergeometric_2F1.hpp @@ -29,7 +29,7 @@ namespace math { template * = nullptr, require_any_var_t* = nullptr> -inline return_type_t hypergeometric_2F1(const Ta1& a1, +inline return_type_t hypergeometric_2F1(const Ta1& a1, const Ta2& a2, const Tb& b, const Tz& z) { diff --git a/stan/math/rev/fun/hypergeometric_pFq.hpp b/stan/math/rev/fun/hypergeometric_pFq.hpp index 140c24adc05..008d91e11e8 100644 --- a/stan/math/rev/fun/hypergeometric_pFq.hpp +++ b/stan/math/rev/fun/hypergeometric_pFq.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -25,27 +26,24 @@ template ::value, bool grad_b = !is_constant::value, bool grad_z = !is_constant::value, - require_all_matrix_t* = nullptr, + require_all_vector_t* = nullptr, require_return_type_t* = nullptr> -inline var hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) { - arena_t arena_a = a; - arena_t arena_b = b; - auto pfq_val = hypergeometric_pFq(a.val(), b.val(), value_of(z)); +inline var hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) { + auto&& arena_a = to_arena(as_column_vector_or_scalar(std::forward(a))); + auto&& arena_b = to_arena(as_column_vector_or_scalar(std::forward(b))); + auto pfq_val = hypergeometric_pFq(arena_a.val(), arena_b.val(), value_of(z)); return make_callback_var( pfq_val, [arena_a, arena_b, z, pfq_val](auto& vi) mutable { auto grad_tuple = grad_pFq( pfq_val, arena_a.val(), arena_b.val(), value_of(z)); if constexpr (grad_a) { - forward_as>(arena_a).adj() - += vi.adj() * std::get<0>(grad_tuple); + arena_a.adj() += vi.adj() * std::get<0>(grad_tuple); } if constexpr (grad_b) { - forward_as>(arena_b).adj() - += vi.adj() * std::get<1>(grad_tuple); + arena_b.adj() += vi.adj() * std::get<1>(grad_tuple); } if constexpr (grad_z) { - forward_as>(z).adj() - += vi.adj() * std::get<2>(grad_tuple); + z.adj() += vi.adj() * std::get<2>(grad_tuple); } }); } diff --git a/test/sig_utils.py b/test/sig_utils.py index 20c013f01b3..2018f6e4809 100644 --- a/test/sig_utils.py +++ b/test/sig_utils.py @@ -153,6 +153,7 @@ def get_cpp_type(stan_type): ignored = [ "std_normal_qf", # synonym for inv_Phi "if_else", + "hypergeometric_3F2", # requires arguments of specific lengths ] # these are all slight renames compared to stan math diff --git a/test/unit/math/mix/fun/hypergeometric_1F0_test.cpp b/test/unit/math/mix/fun/hypergeometric_1F0_test.cpp index a1089461847..184dc91feda 100644 --- a/test/unit/math/mix/fun/hypergeometric_1F0_test.cpp +++ b/test/unit/math/mix/fun/hypergeometric_1F0_test.cpp @@ -1,9 +1,9 @@ #include -TEST(mathMixScalFun, hypergeometric_1f0) { +TEST(mathMixScalFun, hypergeometric_1F0) { auto f = [](const auto& x1, const auto& x2) { - using stan::math::hypergeometric_1f0; - return hypergeometric_1f0(x1, x2); + using stan::math::hypergeometric_1F0; + return hypergeometric_1F0(x1, x2); }; stan::test::expect_ad(f, 5, 0.3); diff --git a/test/unit/math/mix/fun/hypergeometric_pFq_test.cpp b/test/unit/math/mix/fun/hypergeometric_pFq_test.cpp index ba498378459..c5e2b5bad4e 100644 --- a/test/unit/math/mix/fun/hypergeometric_pFq_test.cpp +++ b/test/unit/math/mix/fun/hypergeometric_pFq_test.cpp @@ -2,6 +2,9 @@ #include TEST(mathMixScalFun, hyper_2f2) { + using stan::math::to_array_1d; + using stan::math::to_row_vector; + auto f = [](const auto& a, const auto& b, const auto& z) { using stan::math::hypergeometric_pFq; return hypergeometric_pFq(a, b, z); @@ -14,6 +17,9 @@ TEST(mathMixScalFun, hyper_2f2) { double z = 4; stan::test::expect_ad(f, in1, in2, z); + stan::test::expect_ad(f, to_array_1d(in1), to_row_vector(in2), z); + stan::test::expect_ad(f, to_row_vector(in1), to_array_1d(in2), z); + stan::test::expect_ad(f, to_array_1d(in1), to_array_1d(in2), z); } TEST(mathMixScalFun, hyper_2f3) { diff --git a/test/unit/math/prim/fun/hypergeometric_1F0_test.cpp b/test/unit/math/prim/fun/hypergeometric_1F0_test.cpp index 6edc38bb99d..ae935f32145 100644 --- a/test/unit/math/prim/fun/hypergeometric_1F0_test.cpp +++ b/test/unit/math/prim/fun/hypergeometric_1F0_test.cpp @@ -3,21 +3,21 @@ #include #include -TEST(MathFunctions, hypergeometric_1f0Double) { - using stan::math::hypergeometric_1f0; +TEST(MathFunctions, hypergeometric_1F0Double) { + using stan::math::hypergeometric_1F0; using stan::math::inv; - EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1f0(3, 0.4)); - EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1f0(2, -0.4)); - EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1f0(16.0, 0.3)); - EXPECT_FLOAT_EQ(0.531441, hypergeometric_1f0(-6.0, 0.1)); + EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1F0(3, 0.4)); + EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1F0(2, -0.4)); + EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1F0(16.0, 0.3)); + EXPECT_FLOAT_EQ(0.531441, hypergeometric_1F0(-6.0, 0.1)); } -TEST(MathFunctions, hypergeometric_1f0_throw) { - using stan::math::hypergeometric_1f0; +TEST(MathFunctions, hypergeometric_1F0_throw) { + using stan::math::hypergeometric_1F0; - EXPECT_THROW(hypergeometric_1f0(2.1, 1.0), std::domain_error); - EXPECT_THROW(hypergeometric_1f0(0.5, 1.5), std::domain_error); - EXPECT_THROW(hypergeometric_1f0(0.5, -1.0), std::domain_error); - EXPECT_THROW(hypergeometric_1f0(0.5, -1.5), std::domain_error); + EXPECT_THROW(hypergeometric_1F0(2.1, 1.0), std::domain_error); + EXPECT_THROW(hypergeometric_1F0(0.5, 1.5), std::domain_error); + EXPECT_THROW(hypergeometric_1F0(0.5, -1.0), std::domain_error); + EXPECT_THROW(hypergeometric_1F0(0.5, -1.5), std::domain_error); } diff --git a/test/unit/math/prim/fun/hypergeometric_3F2_test.cpp b/test/unit/math/prim/fun/hypergeometric_3F2_test.cpp index 2a2910165aa..1ca327223d5 100644 --- a/test/unit/math/prim/fun/hypergeometric_3F2_test.cpp +++ b/test/unit/math/prim/fun/hypergeometric_3F2_test.cpp @@ -3,8 +3,16 @@ // converge TEST(MathPrimScalFun, F32_converges_by_z) { - EXPECT_NEAR(2.5, - stan::math::hypergeometric_3F2({1.0, 1.0, 1.0}, {1.0, 1.0}, 0.6), + using stan::math::hypergeometric_3F2; + using stan::math::to_row_vector; + using stan::math::to_vector; + std::vector a = {1.0, 1.0, 1.0}; + std::vector b = {1.0, 1.0}; + double z = 0.6; + + EXPECT_NEAR(2.5, hypergeometric_3F2(a, b, z), 1e-8); + EXPECT_NEAR(2.5, hypergeometric_3F2(to_vector(a), to_vector(b), z), 1e-8); + EXPECT_NEAR(2.5, hypergeometric_3F2(to_row_vector(a), to_row_vector(b), z), 1e-8); } // terminate by zero numerator, no sign-flip diff --git a/test/unit/math/prim/fun/hypergeometric_pFq_test.cpp b/test/unit/math/prim/fun/hypergeometric_pFq_test.cpp index e62c065932c..3e455991c0c 100644 --- a/test/unit/math/prim/fun/hypergeometric_pFq_test.cpp +++ b/test/unit/math/prim/fun/hypergeometric_pFq_test.cpp @@ -4,6 +4,8 @@ TEST(MathFunctions, hypergeometric_pFq_values) { using Eigen::VectorXd; using stan::math::hypergeometric_pFq; + using stan::math::to_array_1d; + using stan::math::to_row_vector; VectorXd a(2); VectorXd b(2); @@ -12,6 +14,10 @@ TEST(MathFunctions, hypergeometric_pFq_values) { double z = 2; EXPECT_FLOAT_EQ(3.8420514314107791, hypergeometric_pFq(a, b, z)); + EXPECT_FLOAT_EQ(3.8420514314107791, + hypergeometric_pFq(to_row_vector(a), to_row_vector(b), z)); + EXPECT_FLOAT_EQ(3.8420514314107791, + hypergeometric_pFq(to_array_1d(a), to_array_1d(b), z)); a << 6, 4; b << 3, 1;