Skip to content

Commit 0717118

Browse files
authored
Merge pull request #3160 from stan-dev/fix/jachymb-inv_logistic
Use logistic function from eigen (based on jachymb's PR)
2 parents 8b8057a + b92cf29 commit 0717118

File tree

3 files changed

+71
-27
lines changed

3 files changed

+71
-27
lines changed

stan/math/prim/fun/inv_logit.hpp

+34-17
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ namespace math {
4949
* @return Inverse logit of argument.
5050
*/
5151
inline double inv_logit(double a) {
52-
using std::exp;
5352
if (a < 0) {
54-
double exp_a = exp(a);
53+
double exp_a = std::exp(a);
5554
if (a < LOG_EPSILON) {
5655
return exp_a;
5756
}
58-
return exp_a / (1 + exp_a);
57+
return exp_a / (1.0 + exp_a);
5958
}
60-
return inv(1 + exp(-a));
59+
return inv(1 + std::exp(-a));
6160
}
6261

6362
/**
@@ -69,28 +68,46 @@ inline double inv_logit(double a) {
6968
*/
7069
struct inv_logit_fun {
7170
template <typename T>
72-
static inline auto fun(const T& x) {
73-
return inv_logit(x);
71+
static inline auto fun(T&& x) {
72+
return inv_logit(std::forward<T>(x));
7473
}
7574
};
7675

7776
/**
78-
* Vectorized version of inv_logit().
77+
* Vectorized version of inv_logit() for containers containing ad types.
7978
*
80-
* @tparam T type of container
81-
* @param x container
79+
* @tparam T type of std::vector
80+
* @param x std::vector
8281
* @return Inverse logit applied to each value in x.
8382
*/
84-
template <
85-
typename T, require_not_var_matrix_t<T>* = nullptr,
86-
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
87-
inline auto inv_logit(const T& x) {
88-
return apply_scalar_unary<inv_logit_fun, T>::apply(x);
83+
template <typename Container, require_ad_container_t<Container>* = nullptr,
84+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
85+
Container>* = nullptr,
86+
require_not_rev_matrix_t<Container>* = nullptr>
87+
inline auto inv_logit(Container&& x) {
88+
return apply_scalar_unary<inv_logit_fun, Container>::apply(
89+
std::forward<Container>(x));
8990
}
9091

91-
// TODO(Tadej): Eigen is introducing their implementation logistic() of this
92-
// in 3.4. Use that once we switch to Eigen 3.4
93-
92+
/**
93+
* Vectorized version of inv_logit() for containers with arithmetic scalar
94+
* types.
95+
*
96+
* @tparam T A type of either `std::vector` or a type that directly inherits
97+
* from `Eigen::DenseBase`. The inner scalar type must not have an auto diff
98+
* scalar type.
99+
* @param x Eigen expression
100+
* @return Inverse logit applied to each value in x.
101+
*/
102+
template <typename Container,
103+
require_container_bt<std::is_arithmetic, Container>* = nullptr,
104+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
105+
Container>* = nullptr>
106+
inline auto inv_logit(Container&& x) {
107+
return apply_vector_unary<Container>::apply(
108+
std::forward<Container>(x),
109+
[](const auto& v) { return v.array().logistic(); });
110+
}
94111
} // namespace math
95112
} // namespace stan
96113

stan/math/prim/functor/apply_scalar_unary.hpp

+12-10
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
5858
* @return Componentwise application of the function specified
5959
* by F to the specified matrix.
6060
*/
61-
static inline auto apply(const T& x) {
61+
static inline auto apply(const std::decay_t<T>& x) {
6262
return x.unaryExpr([](auto&& x) {
6363
return apply_scalar_unary<F, std::decay_t<decltype(x)>>::apply(x);
6464
});
@@ -69,7 +69,7 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
6969
* expression template of type T.
7070
*/
7171
using return_t = std::decay_t<decltype(
72-
apply_scalar_unary<F, T>::apply(std::declval<T>()))>;
72+
apply_scalar_unary<F, std::decay_t<T>>::apply(std::declval<T>()))>;
7373
};
7474

7575
/**
@@ -83,7 +83,8 @@ struct apply_scalar_unary<F, T, require_floating_point_t<T>> {
8383
/**
8484
* The return type, double.
8585
*/
86-
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
86+
using return_t
87+
= std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;
8788

8889
/**
8990
* Apply the function specified by F to the specified argument.
@@ -114,11 +115,12 @@ struct apply_scalar_unary<F, T, require_complex_t<T>> {
114115
* @param x Argument scalar.
115116
* @return Result of applying F to the scalar.
116117
*/
117-
static inline auto apply(const T& x) { return F::fun(x); }
118+
static inline auto apply(const std::decay_t<T>& x) { return F::fun(x); }
118119
/**
119120
* The return type
120121
*/
121-
using return_t = std::decay_t<decltype(F::fun(std::declval<T>()))>;
122+
using return_t
123+
= std::decay_t<decltype(F::fun(std::declval<std::decay_t<T>>()))>;
122124
};
123125

124126
/**
@@ -157,13 +159,13 @@ struct apply_scalar_unary<F, T, require_integral_t<T>> {
157159
* @tparam T Type of element contained in standard vector.
158160
*/
159161
template <typename F, typename T>
160-
struct apply_scalar_unary<F, std::vector<T>> {
162+
struct apply_scalar_unary<F, T, require_std_vector_t<T>> {
161163
/**
162164
* Return type, which is calculated recursively as a standard
163165
* vector of the return type of the contained type T.
164166
*/
165-
using return_t = typename std::vector<
166-
plain_type_t<typename apply_scalar_unary<F, T>::return_t>>;
167+
using return_t = typename std::vector<plain_type_t<
168+
typename apply_scalar_unary<F, value_type_t<std::decay_t<T>>>::return_t>>;
167169

168170
/**
169171
* Apply the function specified by F elementwise to the
@@ -174,10 +176,10 @@ struct apply_scalar_unary<F, std::vector<T>> {
174176
* @return Elementwise application of F to the elements of the
175177
* container.
176178
*/
177-
static inline auto apply(const std::vector<T>& x) {
179+
static inline auto apply(const std::decay_t<T>& x) {
178180
return_t fx(x.size());
179181
for (size_t i = 0; i < x.size(); ++i) {
180-
fx[i] = apply_scalar_unary<F, T>::apply(x[i]);
182+
fx[i] = apply_scalar_unary<F, value_type_t<T>>::apply(x[i]);
181183
}
182184
return fx;
183185
}

stan/math/rev/fun/inv_logit.hpp

+25
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,31 @@ inline auto inv_logit(const var_value<T>& a) {
3131
});
3232
}
3333

34+
/**
35+
* The inverse logit function for Eigen expressions with var value type.
36+
*
37+
* See inv_logit() for the double-based version.
38+
*
39+
* The derivative of inverse logit is
40+
*
41+
* \f$\frac{d}{dx} \mbox{logit}^{-1}(x) = \mbox{logit}^{-1}(x) (1 -
42+
* \mbox{logit}^{-1}(x))\f$.
43+
*
44+
* @tparam T type of Eigen expression
45+
* @param x Eigen expression
46+
* @return Inverse logit of argument.
47+
*/
48+
template <typename T, require_eigen_vt<is_var, T>* = nullptr>
49+
inline auto inv_logit(T&& x) {
50+
auto x_arena = to_arena(std::forward<T>(x));
51+
arena_t<T> ret = inv_logit(x_arena.val());
52+
reverse_pass_callback([x_arena, ret]() mutable {
53+
x_arena.adj().array()
54+
+= ret.adj().array() * ret.val().array() * (1.0 - ret.val().array());
55+
});
56+
return ret;
57+
}
58+
3459
} // namespace math
3560
} // namespace stan
3661
#endif

0 commit comments

Comments
 (0)