Skip to content

Commit 935c649

Browse files
committed
Added xt::convolve
1 parent f49d573 commit 935c649

2 files changed

Lines changed: 111 additions & 0 deletions

File tree

include/xtensor/xmath.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,6 +3013,94 @@ namespace detail {
30133013
return cov(eval(stack(xtuple(x, y))));
30143014
}
30153015
}
3016+
3017+
3018+
3019+
/*
3020+
* convolution mode placeholders for selecting the algorithm
3021+
* used in computing a 1D convolution.
3022+
* Same as NumPy's mode parameter.
3023+
*/
3024+
namespace convolve_mode
3025+
{
3026+
struct valid{};
3027+
struct full{};
3028+
}
3029+
3030+
namespace detail {
3031+
template <class E1, class E2>
3032+
inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::valid)
3033+
{
3034+
using value_type = typename std::decay<E1>::type::value_type;
3035+
3036+
size_t const na = e1.size();
3037+
size_t const nv = e2.size();
3038+
size_t const n = na - nv + 1;
3039+
xt::xtensor<value_type, 1> out = xt::zeros<value_type>({ n });
3040+
for (size_t i = 0; i < n; i++)
3041+
{
3042+
for (int j = 0; j < nv; j++)
3043+
{
3044+
out(i) += e1(j) * e2(j + i);
3045+
}
3046+
}
3047+
return out;
3048+
}
3049+
3050+
template <class E1, class E2>
3051+
inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::full mode)
3052+
{
3053+
using value_type = typename std::decay<E1>::type::value_type;
3054+
3055+
size_t const na = e1.size();
3056+
size_t const nv = e2.size();
3057+
size_t const n = na + nv - 1;
3058+
xt::xtensor<value_type, 1> out = xt::zeros<value_type>({ n });
3059+
for (size_t i = 0; i < n; i++)
3060+
{
3061+
size_t const jmn = (i >= nv - 1) ? i - (nv - 1) : 0;
3062+
size_t const jmx = (i < na - 1) ? i : na - 1;
3063+
for (size_t j = jmn; j <= jmx; ++j)
3064+
{
3065+
out(i) += e1(j) * e2(i - j);
3066+
}
3067+
}
3068+
return out;
3069+
}
3070+
}
3071+
3072+
/*
3073+
* @brief computes the 1D convolution between two 1D expressions
3074+
*
3075+
* @param a 1D expression
3076+
* @param v 1D expression
3077+
* @param mode placeholder Select algorithm #convolve_mode
3078+
*
3079+
* @detail the algorithm convolves a with v and will incur a copy overhead
3080+
* should v be longer than a.
3081+
*/
3082+
template <class E1, class E2, class E3>
3083+
inline auto convolve(E1&& a, E2&& v, E3 mode)
3084+
{
3085+
3086+
if (a.dimension() != 1 || v.dimension() != 1)
3087+
{
3088+
XTENSOR_THROW(std::runtime_error, "Invalid dimentions convolution arguments must be 1D expressions");
3089+
}
3090+
3091+
XTENSOR_ASSERT(a.size() > 0 && v.size() > 0);
3092+
3093+
//swap them so a is always the longest one
3094+
if (a.size() < v.size())
3095+
{
3096+
return detail::convolve_impl(std::forward<E2>(v), std::forward<E1>(a), mode);
3097+
}
3098+
else
3099+
{
3100+
return detail::convolve_impl(std::forward<E1>(a), std::forward<E2>(v), mode);
3101+
}
3102+
}
30163103
}
30173104

3105+
30183106
#endif

test/test_xmath.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,4 +915,27 @@ namespace xt
915915

916916
EXPECT_EQ(expected, xt::cov(x, y));
917917
}
918+
919+
920+
TEST(xmath, convolve_full)
921+
{
922+
xt::xarray<double> x = { 1.0, 3.0, 1.0 };
923+
xt::xarray<double> y = { 1.0, 1.0, 1.0 };
924+
xt::xarray<double> expected = { 1, 4, 5, 4, 1 };
925+
926+
auto result = xt::convolve(x, y, xt::convolve_mode::full());
927+
928+
EXPECT_EQ(result, expected);
929+
}
930+
931+
TEST(xmath, convolve_valid)
932+
{
933+
xt::xarray<double> x = { 3.0, 1.0, 1.0 };
934+
xt::xarray<double> y = { 1.0, 1.0, 1.0 };
935+
xt::xarray<double> expected = { 5 };
936+
937+
auto result = xt::convolve(x, y, xt::convolve_mode::valid());
938+
939+
EXPECT_EQ(result, expected);
940+
}
918941
}

0 commit comments

Comments
 (0)