Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stan/math/fwd/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ namespace math {
template <typename Ta, typename Tz, typename FvarT = return_type_t<Ta, Tz>,
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
require_any_fvar_t<Ta, Tz>* = nullptr>
FvarT hypergeometric_1f0(const Ta& a, const Tz& z) {
FvarT hypergeometric_1F0(const Ta& a, const Tz& z) {
partials_type_t<Ta> a_val = value_of(a);
partials_type_t<Tz> z_val = value_of(z);
FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0);
FvarT rtn = FvarT(hypergeometric_1F0(a_val, z_val), 0.0);
if (!is_constant_all<Ta>::value) {
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val);
}
Expand Down
4 changes: 2 additions & 2 deletions stan/math/fwd/fun/hypergeometric_2F1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ namespace math {
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
require_all_stan_scalar_t<Ta1, Ta2, Tb, Tz>* = nullptr,
require_any_fvar_t<Ta1, Ta2, Tb, Tz>* = nullptr>
inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
inline return_type_t<Ta1, Ta2, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
const Ta2& a2,
const Tb& b,
const Tz& z) {
using fvar_t = return_type_t<Ta1, Ta1, Tb, Tz>;
using fvar_t = return_type_t<Ta1, Ta2, Tb, Tz>;

auto a1_val = value_of(a1);
auto a2_val = value_of(a2);
Expand Down
30 changes: 13 additions & 17 deletions stan/math/fwd/fun/hypergeometric_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <stan/math/prim/fun/dot_product.hpp>
#include <stan/math/prim/fun/grad_pFq.hpp>
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
#include <stan/math/prim/fun/to_ref.hpp>

namespace stan {
namespace math {
Expand All @@ -30,33 +32,27 @@ template <typename Ta, typename Tb, typename Tz,
bool grad_z = !is_constant<Tz>::value,
require_all_vector_t<Ta, Tb>* = nullptr,
require_fvar_t<FvarT>* = nullptr>
inline FvarT hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
using PartialsT = partials_type_t<FvarT>;
using ARefT = ref_type_t<Ta>;
using BRefT = ref_type_t<Tb>;

ARefT a_ref = a;
BRefT b_ref = b;
inline FvarT hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
auto&& a_ref = to_ref(as_column_vector_or_scalar(a));
auto&& b_ref = to_ref(as_column_vector_or_scalar(b));
auto&& a_val = value_of(a_ref);
auto&& b_val = value_of(b_ref);
auto&& z_val = value_of(z);
PartialsT pfq_val = hypergeometric_pFq(a_val, b_val, z_val);

partials_type_t<FvarT> pfq_val = hypergeometric_pFq(a_val, b_val, z_val);
auto grad_tuple
= grad_pFq<grad_a, grad_b, grad_z>(pfq_val, a_val, b_val, z_val);

FvarT rtn = FvarT(pfq_val, 0.0);

if (grad_a) {
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, ARefT>>(a_ref).d(),
std::get<0>(grad_tuple));
if constexpr (grad_a) {
rtn.d_ += dot_product(a_ref.d(), std::get<0>(grad_tuple));
}
if (grad_b) {
rtn.d_ += dot_product(forward_as<promote_scalar_t<FvarT, BRefT>>(b_ref).d(),
std::get<1>(grad_tuple));
if constexpr (grad_b) {
rtn.d_ += dot_product(b_ref.d(), std::get<1>(grad_tuple));
}
if (grad_z) {
rtn.d_ += forward_as<promote_scalar_t<FvarT, Tz>>(z).d_
* std::get<2>(grad_tuple);
if constexpr (grad_z) {
rtn.d_ += z.d_ * std::get<2>(grad_tuple);
}

return rtn;
Expand Down
5 changes: 2 additions & 3 deletions stan/math/prim/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ namespace math {
* @return Hypergeometric 1F0 function
*/
template <typename Ta, typename Tz, require_all_arithmetic_t<Ta, Tz>* = nullptr>
return_type_t<Ta, Tz> hypergeometric_1f0(const Ta& a, const Tz& z) {
constexpr const char* function = "hypergeometric_1f0";
check_less("hypergeometric_1f0", "abs(z)", std::fabs(z), 1.0);
return_type_t<Ta, Tz> hypergeometric_1F0(const Ta& a, const Tz& z) {
check_less("hypergeometric_1F0", "abs(z)", std::fabs(z), 1.0);

return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>());
}
Expand Down
6 changes: 3 additions & 3 deletions stan/math/prim/fun/hypergeometric_2F1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace internal {
* @return Gauss hypergeometric function
*/
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
typename RtnT = boost::optional<return_type_t<Ta1, Ta1, Tb, Tz>>,
typename RtnT = boost::optional<return_type_t<Ta1, Ta2, Tb, Tz>>,
require_all_arithmetic_t<Ta1, Ta2, Tb, Tz>* = nullptr>
inline RtnT hyper_2F1_special_cases(const Ta1& a1, const Ta2& a2, const Tb& b,
const Tz& z) {
Expand Down Expand Up @@ -148,10 +148,10 @@ inline RtnT hyper_2F1_special_cases(const Ta1& a1, const Ta2& a2, const Tb& b,
* @return Gauss hypergeometric function
*/
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
typename ScalarT = return_type_t<Ta1, Ta1, Tb, Tz>,
typename ScalarT = return_type_t<Ta1, Ta2, Tb, Tz>,
typename OptT = boost::optional<ScalarT>,
require_all_arithmetic_t<Ta1, Ta2, Tb, Tz>* = nullptr>
inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
inline return_type_t<Ta1, Ta2, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
const Ta2& a2,
const Tb& b,
const Tz& z) {
Expand Down
15 changes: 11 additions & 4 deletions stan/math/prim/fun/hypergeometric_3F2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,20 @@ template <typename Ta, typename Tb, typename Tz,
require_all_vector_t<Ta, Tb>* = nullptr,
require_stan_scalar_t<Tz>* = nullptr>
inline auto hypergeometric_3F2(const Ta& a, const Tb& b, const Tz& z) {
check_3F2_converges("hypergeometric_3F2", a[0], a[1], a[2], b[0], b[1], z);
check_size_match("hypergeometric_3F2", "a", a.size(), "3", 3);
check_size_match("hypergeometric_3F2", "b", b.size(), "2", 2);

auto a_ref = to_vector(a);
auto b_ref = to_vector(b);

check_3F2_converges("hypergeometric_3F2", a_ref[0], a_ref[1], a_ref[2],
b_ref[0], b_ref[1], z);
// Boost's pFq throws convergence errors in some cases, fallback to naive
// infinite-sum approach (tests pass for these)
if (z == 1.0 && (sum(b) - sum(a)) < 0.0) {
return internal::hypergeometric_3F2_infsum(a, b, z);
if (z == 1.0 && (sum(b_ref) - sum(a_ref)) < 0.0) {
return internal::hypergeometric_3F2_infsum(a_ref, b_ref, z);
}
return hypergeometric_pFq(to_vector(a), to_vector(b), z);
return hypergeometric_pFq(a_ref, b_ref, z);
}

/**
Expand Down
12 changes: 5 additions & 7 deletions stan/math/prim/fun/hypergeometric_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_not_nan.hpp>
#include <stan/math/prim/err/check_finite.hpp>
#include <stan/math/prim/fun/to_row_vector.hpp>
#include <boost/math/special_functions/hypergeometric_pFq.hpp>

namespace stan {
Expand All @@ -14,10 +15,6 @@ namespace math {
* input arguments:
* \f$_pF_q(a_1,...,a_p;b_1,...,b_q;z)\f$
*
* This function is not intended to be exposed to end users, only
* used for p & q values that are stable with the grad_pFq
* implementation.
*
* See 'grad_pFq.hpp' for the derivatives wrt each parameter
*
* @param[in] a Vector of 'a' arguments to function
Expand All @@ -26,7 +23,7 @@ namespace math {
* @return Generalized hypergeometric function
*/
template <typename Ta, typename Tb, typename Tz,
require_all_eigen_st<std::is_arithmetic, Ta, Tb>* = nullptr,
require_all_vector_st<std::is_arithmetic, Ta, Tb>* = nullptr,
require_arithmetic_t<Tz>* = nullptr>
return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
const Tz& z) {
Expand All @@ -47,8 +44,9 @@ return_type_t<Ta, Tb, Tz> hypergeometric_pFq(const Ta& a, const Tb& b,
std::stringstream msg;
msg << "hypergeometric function pFq does not meet convergence "
<< "conditions with given arguments. "
<< "a: " << a_ref << ", b: " << b_ref << ", "
<< ", z: " << z;
<< "a: " << to_row_vector(a_ref) << ", "
<< "b: " << to_row_vector(b_ref) << ", "
<< "z: " << z;
throw std::domain_error(msg.str());
}

Expand Down
4 changes: 2 additions & 2 deletions stan/math/rev/fun/hypergeometric_1F0.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ namespace math {
template <typename Ta, typename Tz,
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
require_any_var_t<Ta, Tz>* = nullptr>
var hypergeometric_1f0(const Ta& a, const Tz& z) {
var hypergeometric_1F0(const Ta& a, const Tz& z) {
double a_val = value_of(a);
double z_val = value_of(z);
double rtn = hypergeometric_1f0(a_val, z_val);
double rtn = hypergeometric_1F0(a_val, z_val);
return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable {
if (!is_constant_all<Ta>::value) {
forward_as<var>(a).adj() += vi.adj() * -rtn * log1m(z_val);
Expand Down
2 changes: 1 addition & 1 deletion stan/math/rev/fun/hypergeometric_2F1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace math {
template <typename Ta1, typename Ta2, typename Tb, typename Tz,
require_all_stan_scalar_t<Ta1, Ta2, Tb, Tz>* = nullptr,
require_any_var_t<Ta1, Ta2, Tb, Tz>* = nullptr>
inline return_type_t<Ta1, Ta1, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
inline return_type_t<Ta1, Ta2, Tb, Tz> hypergeometric_2F1(const Ta1& a1,
const Ta2& a2,
const Tb& b,
const Tz& z) {
Expand Down
20 changes: 9 additions & 11 deletions stan/math/rev/fun/hypergeometric_pFq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stan/math/rev/core.hpp>
#include <stan/math/rev/meta.hpp>
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
#include <stan/math/prim/fun/grad_pFq.hpp>
#include <stan/math/prim/fun/hypergeometric_pFq.hpp>

Expand All @@ -25,27 +26,24 @@ template <typename Ta, typename Tb, typename Tz,
bool grad_a = !is_constant<Ta>::value,
bool grad_b = !is_constant<Tb>::value,
bool grad_z = !is_constant<Tz>::value,
require_all_matrix_t<Ta, Tb>* = nullptr,
require_all_vector_t<Ta, Tb>* = nullptr,
require_return_type_t<is_var, Ta, Tb, Tz>* = nullptr>
inline var hypergeometric_pFq(const Ta& a, const Tb& b, const Tz& z) {
arena_t<Ta> arena_a = a;
arena_t<Tb> arena_b = b;
auto pfq_val = hypergeometric_pFq(a.val(), b.val(), value_of(z));
inline var hypergeometric_pFq(Ta&& a, Tb&& b, Tz&& z) {
auto&& arena_a = to_arena(as_column_vector_or_scalar(std::forward<Ta>(a)));
auto&& arena_b = to_arena(as_column_vector_or_scalar(std::forward<Tb>(b)));
auto pfq_val = hypergeometric_pFq(arena_a.val(), arena_b.val(), value_of(z));
return make_callback_var(
pfq_val, [arena_a, arena_b, z, pfq_val](auto& vi) mutable {
auto grad_tuple = grad_pFq<grad_a, grad_b, grad_z>(
pfq_val, arena_a.val(), arena_b.val(), value_of(z));
if constexpr (grad_a) {
forward_as<promote_scalar_t<var, Ta>>(arena_a).adj()
+= vi.adj() * std::get<0>(grad_tuple);
arena_a.adj() += vi.adj() * std::get<0>(grad_tuple);
}
if constexpr (grad_b) {
forward_as<promote_scalar_t<var, Tb>>(arena_b).adj()
+= vi.adj() * std::get<1>(grad_tuple);
arena_b.adj() += vi.adj() * std::get<1>(grad_tuple);
}
if constexpr (grad_z) {
forward_as<promote_scalar_t<var, Tz>>(z).adj()
+= vi.adj() * std::get<2>(grad_tuple);
z.adj() += vi.adj() * std::get<2>(grad_tuple);
}
});
}
Expand Down
1 change: 1 addition & 0 deletions test/sig_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def get_cpp_type(stan_type):
ignored = [
"std_normal_qf", # synonym for inv_Phi
"if_else",
"hypergeometric_3F2", # requires arguments of specific lengths
]

# these are all slight renames compared to stan math
Expand Down
6 changes: 3 additions & 3 deletions test/unit/math/mix/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <test/unit/math/test_ad.hpp>

TEST(mathMixScalFun, hypergeometric_1f0) {
TEST(mathMixScalFun, hypergeometric_1F0) {
auto f = [](const auto& x1, const auto& x2) {
using stan::math::hypergeometric_1f0;
return hypergeometric_1f0(x1, x2);
using stan::math::hypergeometric_1F0;
return hypergeometric_1F0(x1, x2);
};

stan::test::expect_ad(f, 5, 0.3);
Expand Down
6 changes: 6 additions & 0 deletions test/unit/math/mix/fun/hypergeometric_pFq_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <limits>

TEST(mathMixScalFun, hyper_2f2) {
using stan::math::to_array_1d;
using stan::math::to_row_vector;

auto f = [](const auto& a, const auto& b, const auto& z) {
using stan::math::hypergeometric_pFq;
return hypergeometric_pFq(a, b, z);
Expand All @@ -14,6 +17,9 @@ TEST(mathMixScalFun, hyper_2f2) {
double z = 4;

stan::test::expect_ad(f, in1, in2, z);
stan::test::expect_ad(f, to_array_1d(in1), to_row_vector(in2), z);
stan::test::expect_ad(f, to_row_vector(in1), to_array_1d(in2), z);
stan::test::expect_ad(f, to_array_1d(in1), to_array_1d(in2), z);
}

TEST(mathMixScalFun, hyper_2f3) {
Expand Down
24 changes: 12 additions & 12 deletions test/unit/math/prim/fun/hypergeometric_1F0_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
#include <cmath>
#include <limits>

TEST(MathFunctions, hypergeometric_1f0Double) {
using stan::math::hypergeometric_1f0;
TEST(MathFunctions, hypergeometric_1F0Double) {
using stan::math::hypergeometric_1F0;
using stan::math::inv;

EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1f0(3, 0.4));
EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1f0(2, -0.4));
EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1f0(16.0, 0.3));
EXPECT_FLOAT_EQ(0.531441, hypergeometric_1f0(-6.0, 0.1));
EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1F0(3, 0.4));
EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1F0(2, -0.4));
EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1F0(16.0, 0.3));
EXPECT_FLOAT_EQ(0.531441, hypergeometric_1F0(-6.0, 0.1));
}

TEST(MathFunctions, hypergeometric_1f0_throw) {
using stan::math::hypergeometric_1f0;
TEST(MathFunctions, hypergeometric_1F0_throw) {
using stan::math::hypergeometric_1F0;

EXPECT_THROW(hypergeometric_1f0(2.1, 1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, 1.5), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1f0(0.5, -1.5), std::domain_error);
EXPECT_THROW(hypergeometric_1F0(2.1, 1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1F0(0.5, 1.5), std::domain_error);
EXPECT_THROW(hypergeometric_1F0(0.5, -1.0), std::domain_error);
EXPECT_THROW(hypergeometric_1F0(0.5, -1.5), std::domain_error);
}
12 changes: 10 additions & 2 deletions test/unit/math/prim/fun/hypergeometric_3F2_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@

// converge
TEST(MathPrimScalFun, F32_converges_by_z) {
EXPECT_NEAR(2.5,
stan::math::hypergeometric_3F2({1.0, 1.0, 1.0}, {1.0, 1.0}, 0.6),
using stan::math::hypergeometric_3F2;
using stan::math::to_row_vector;
using stan::math::to_vector;
std::vector<double> a = {1.0, 1.0, 1.0};
std::vector<double> b = {1.0, 1.0};
double z = 0.6;

EXPECT_NEAR(2.5, hypergeometric_3F2(a, b, z), 1e-8);
EXPECT_NEAR(2.5, hypergeometric_3F2(to_vector(a), to_vector(b), z), 1e-8);
EXPECT_NEAR(2.5, hypergeometric_3F2(to_row_vector(a), to_row_vector(b), z),
1e-8);
}
// terminate by zero numerator, no sign-flip
Expand Down
6 changes: 6 additions & 0 deletions test/unit/math/prim/fun/hypergeometric_pFq_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
TEST(MathFunctions, hypergeometric_pFq_values) {
using Eigen::VectorXd;
using stan::math::hypergeometric_pFq;
using stan::math::to_array_1d;
using stan::math::to_row_vector;

VectorXd a(2);
VectorXd b(2);
Expand All @@ -12,6 +14,10 @@ TEST(MathFunctions, hypergeometric_pFq_values) {
double z = 2;

EXPECT_FLOAT_EQ(3.8420514314107791, hypergeometric_pFq(a, b, z));
EXPECT_FLOAT_EQ(3.8420514314107791,
hypergeometric_pFq(to_row_vector(a), to_row_vector(b), z));
EXPECT_FLOAT_EQ(3.8420514314107791,
hypergeometric_pFq(to_array_1d(a), to_array_1d(b), z));

a << 6, 4;
b << 3, 1;
Expand Down