Skip to content

Commit b6c6dea

Browse files
authored
Merge pull request #207 from martinRenou/add_where_function
Add where function
2 parents 1ca0a34 + 0226c15 commit b6c6dea

3 files changed

Lines changed: 33 additions & 3 deletions

File tree

docs/source/xarray.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ where ``condition`` is falsy, and it does not evaluate ``b`` where ``condition``
126126
| Python 3 - xarray | C++ 14 - xframe |
127127
+===============================================+===============================================+
128128
| ``xr.where(a > 5, a, b)`` | ``xf::where(a > 5, a, b)`` |
129+
| ``xr.where(a > 5, 100, a)`` | ``xf::where(a > 5, 100, a)`` |
129130
+-----------------------------------------------+-----------------------------------------------+
130131
| ``np.any(a)`` | ``xf::any(a)`` |
131132
+-----------------------------------------------+-----------------------------------------------+
@@ -251,4 +252,3 @@ xframe universal functions are provided for a large set number of mathematical f
251252
+-----------------------------------------------+-----------------------------------------------+
252253
| ``scipy.special.gammaln(a)`` | ``xf::lgamma(a)`` |
253254
+-----------------------------------------------+-----------------------------------------------+
254-

include/xframe/xaxis_math.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ namespace xf
114114
// Needs a fix in xtensor
115115
/*using xt::isclose;
116116
using xt::allclose;*/
117+
118+
using xt::where;
117119
}
118120

119121
#endif

test/test_xvariable_math.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
****************************************************************************/
88

99
#include <algorithm>
10+
1011
#include "gtest/gtest.h"
12+
13+
#include "xtensor/xarray.hpp"
14+
#include "xtensor/xoptional_assembly.hpp"
15+
1116
#include "test_fixture.hpp"
1217

1318
namespace xf
@@ -278,7 +283,7 @@ namespace xf
278283

279284
EXPECT_EQ(fma(sa, sb, a.select(sel)), xf::fma(sa, sb, a).select(sel));
280285
}
281-
286+
282287
TEST(xvariable_math, fmax)
283288
{
284289
variable_type a = make_test_variable();
@@ -657,11 +662,34 @@ namespace xf
657662
EXPECT_EQ(isnan(a.select(sel)), xf::isnan(a).select(sel));
658663
}
659664

665+
TEST(xvariable_math, where)
666+
{
667+
auto missing = xtl::missing<double>();
668+
using data_type = xt::xoptional_assembly<xt::xarray<double>, xt::xarray<bool>>;
669+
670+
variable_type a = make_test_variable();
671+
variable_type b = make_test_variable2();
672+
673+
variable_type res = where(a < 6, b, a);
674+
675+
data_type expected = {{{ 1., 2., missing},
676+
{missing, missing, missing}},
677+
{{ 7., 7., 7.},
678+
{ 9., 9., 9.}}};
679+
EXPECT_EQ(res.data(), expected);
680+
681+
variable_type res2 = where(a < 6, 0., a);
682+
data_type expected2 = {{ 0, 0, missing},
683+
{missing, 0, 6},
684+
{ 7, 8, 9}};
685+
EXPECT_EQ(res2.data(), expected2);
686+
}
687+
660688
// Needs a fix in xtensor
661689
/*TEST(xvariable_math, isclose)
662690
{
663691
variable_type a = make_test_variable();
664692
dict_type sel = make_selector_aa();
665693
EXPECT_TRUE(isclose(a, a).select(sel));
666694
}*/
667-
}
695+
}

0 commit comments

Comments
 (0)