Skip to content

Commit a0076ca

Browse files
authored
Merge pull request #2469 from andrjohns/feature/issue-2385-apply-unsigned
Fix apply_scalar_unary with unsigned and long arithmetic types
2 parents 6c3e3c7 + 208cab0 commit a0076ca

2 files changed

Lines changed: 14 additions & 6 deletions

File tree

stan/math/prim/functor/apply_scalar_unary.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta/is_eigen.hpp>
6+
#include <stan/math/prim/meta/require_generics.hpp>
67
#include <stan/math/prim/meta/is_vector.hpp>
78
#include <stan/math/prim/meta/is_vector_like.hpp>
89
#include <stan/math/prim/meta/plain_type.hpp>
@@ -80,8 +81,8 @@ struct apply_scalar_unary<F, T, require_eigen_t<T>> {
8081
*
8182
* @tparam F Type of function defining static apply function.
8283
*/
83-
template <typename F>
84-
struct apply_scalar_unary<F, double> {
84+
template <typename F, typename T>
85+
struct apply_scalar_unary<F, T, require_floating_point_t<T>> {
8586
/**
8687
* The return type, double.
8788
*/
@@ -96,7 +97,7 @@ struct apply_scalar_unary<F, double> {
9697
* @param x Argument scalar.
9798
* @return Result of applying F to the scalar.
9899
*/
99-
static inline return_t apply(double x) { return F::fun(x); }
100+
static inline return_t apply(T x) { return F::fun(x); }
100101
};
101102

102103
/**
@@ -107,8 +108,8 @@ struct apply_scalar_unary<F, double> {
107108
*
108109
* @tparam F Type of function defining static apply function.
109110
*/
110-
template <typename F>
111-
struct apply_scalar_unary<F, int> {
111+
template <typename F, typename T>
112+
struct apply_scalar_unary<F, T, require_integral_t<T>> {
112113
/**
113114
* The return type, double.
114115
*/
@@ -123,7 +124,7 @@ struct apply_scalar_unary<F, int> {
123124
* @param x Argument scalar.
124125
* @return Result of applying F to the scalar.
125126
*/
126-
static inline return_t apply(int x) { return F::fun(static_cast<double>(x)); }
127+
static inline return_t apply(T x) { return F::fun(static_cast<double>(x)); }
127128
};
128129

129130
/**

test/unit/math/prim/fun/sqrt_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,11 @@ TEST(MathFunctions, sqrtInt) {
66
EXPECT_FLOAT_EQ(std::sqrt(3.0), sqrt(3));
77
EXPECT_FLOAT_EQ(std::sqrt(3.1), sqrt(3.1));
88
EXPECT_TRUE(stan::math::is_nan(sqrt(-2)));
9+
10+
uint32_t ulong = 1;
11+
uint64_t ulonglong = 1;
12+
long double ldouble = 1.5;
13+
EXPECT_FLOAT_EQ(std::sqrt(ulong), sqrt(ulong));
14+
EXPECT_FLOAT_EQ(std::sqrt(ulonglong), sqrt(ulonglong));
15+
EXPECT_FLOAT_EQ(std::sqrt(ldouble), sqrt(ldouble));
916
}

0 commit comments

Comments
 (0)