Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use logistic function from eigen (based on jachymb's PR) #3160

Merged
merged 10 commits into from
Mar 14, 2025
51 changes: 34 additions & 17 deletions stan/math/prim/fun/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@ namespace math {
* @return Inverse logit of argument.
*/
inline double inv_logit(double a) {
using std::exp;
if (a < 0) {
double exp_a = exp(a);
double exp_a = std::exp(a);
if (a < LOG_EPSILON) {
return exp_a;
}
return exp_a / (1 + exp_a);
return exp_a / (1.0 + exp_a);
}
return inv(1 + exp(-a));
return inv(1 + std::exp(-a));
}

/**
Expand All @@ -69,28 +68,46 @@ inline double inv_logit(double a) {
*/
struct inv_logit_fun {
template <typename T>
static inline auto fun(const T& x) {
return inv_logit(x);
static inline auto fun(T&& x) {
return inv_logit(std::forward<T>(x));
}
};

/**
* Vectorized version of inv_logit().
* Vectorized version of inv_logit() for containers containing ad types.
*
* @tparam T type of container
* @param x container
* @tparam T type of std::vector
* @param x std::vector
* @return Inverse logit applied to each value in x.
*/
template <
typename T, require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto inv_logit(const T& x) {
return apply_scalar_unary<inv_logit_fun, T>::apply(x);
template <typename Container, require_ad_container_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr,
require_not_rev_matrix_t<Container>* = nullptr>
inline auto inv_logit(Container&& x) {
return apply_scalar_unary<inv_logit_fun, Container>::apply(
std::forward<Container>(x));
}

// TODO(Tadej): Eigen is introducing their implementation logistic() of this
// in 3.4. Use that once we switch to Eigen 3.4

/**
* Vectorized version of inv_logit() for containers with arithmetic scalar
* types.
*
* @tparam T A type of either `std::vector` or a type that directly inherits
* from `Eigen::DenseBase`. The inner scalar type must not have an auto diff
* scalar type.
* @param x Eigen expression
* @return Inverse logit applied to each value in x.
*/
template <typename Container,
require_container_bt<std::is_arithmetic, Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto inv_logit(Container&& x) {
return apply_vector_unary<Container>::apply(
std::forward<Container>(x),
[](const auto& v) { return v.array().logistic(); });
}
} // namespace math
} // namespace stan

Expand Down
22 changes: 12 additions & 10 deletions stan/math/prim/functor/apply_scalar_unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
* @return Componentwise application of the function specified
* by F to the specified matrix.
*/
static inline auto apply(const T& x) {
static inline auto apply(const std::decay_t<T>& x) {
return x.unaryExpr([](auto&& x) {
return apply_scalar_unary<F, std::decay_t<decltype(x)>>::apply(x);
});
Expand All @@ -69,7 +69,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
* expression template of type T.
*/
using return_t = std::decay_t<decltype(
apply_scalar_unary<F, T>::apply(std::declval<T>()))>;
apply_scalar_unary<F, std::decay_t<T>>::apply(std::declval<T>()))>;
};

/**
Expand All @@ -83,7 +83,8 @@ struct apply_scalar_unary<F, T, require_floating_point_t<T>> {
/**
* The return type, double.
*/
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
using return_t
= std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;

/**
* Apply the function specified by F to the specified argument.
Expand Down Expand Up @@ -114,11 +115,12 @@ struct apply_scalar_unary<F, T, require_complex_t<T>> {
* @param x Argument scalar.
* @return Result of applying F to the scalar.
*/
static inline auto apply(const T& x) { return F::fun(x); }
static inline auto apply(const std::decay_t<T>& x) { return F::fun(x); }
/**
* The return type
*/
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
using return_t
= std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;
};

/**
Expand Down Expand Up @@ -157,13 +159,13 @@ struct apply_scalar_unary<F, T, require_integral_t<T>> {
* @tparam T Type of element contained in standard vector.
*/
template <typename F, typename T>
struct apply_scalar_unary<F, std::vector<T>> {
struct apply_scalar_unary<F, T, require_std_vector_t<T>> {
/**
* Return type, which is calculated recursively as a standard
* vector of the return type of the contained type T.
*/
using return_t = typename std::vector<
plain_type_t<typename apply_scalar_unary<F, T>::return_t>>;
using return_t = typename std::vector<plain_type_t<
typename apply_scalar_unary<F, value_type_t<std::decay_t<T>>>::return_t>>;

/**
* Apply the function specified by F elementwise to the
Expand All @@ -174,10 +176,10 @@ struct apply_scalar_unary<F, std::vector<T>> {
* @return Elementwise application of F to the elements of the
* container.
*/
static inline auto apply(const std::vector<T>& x) {
static inline auto apply(const std::decay_t<T>& x) {
return_t fx(x.size());
for (size_t i = 0; i < x.size(); ++i) {
fx[i] = apply_scalar_unary<F, T>::apply(x[i]);
fx[i] = apply_scalar_unary<F, value_type_t<T>>::apply(x[i]);
}
return fx;
}
Expand Down
25 changes: 25 additions & 0 deletions stan/math/rev/fun/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,31 @@ inline auto inv_logit(const var_value<T>& a) {
});
}

/**
* The inverse logit function for Eigen expressions with var value type.
*
* See inv_logit() for the double-based version.
*
* The derivative of inverse logit is
*
* \f$\frac{d}{dx} \mbox{logit}^{-1}(x) = \mbox{logit}^{-1}(x) (1 -
* \mbox{logit}^{-1}(x))\f$.
*
* @tparam T type of Eigen expression
* @param x Eigen expression
* @return Inverse logit of argument.
*/
template <typename T, require_eigen_vt<is_var, T>* = nullptr>
inline auto inv_logit(T&& x) {
auto x_arena = to_arena(std::forward<T>(x));
arena_t<T> ret = inv_logit(x_arena.val());
reverse_pass_callback([x_arena, ret]() mutable {
x_arena.adj().array()
+= ret.adj().array() * ret.val().array() * (1.0 - ret.val().array());
});
return ret;
}

} // namespace math
} // namespace stan
#endif