Skip to content

Commit e2ffd62

Browse files
committed
Specialize operator= when RHS is chunked
1 parent a1f6b16 commit e2ffd62

1 file changed

Lines changed: 44 additions & 2 deletions

File tree

include/xtensor/xchunked_view.hpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@
1919
namespace xt
2020
{
2121

22+
// SFINAE test if chunked
23+
template <typename T>
24+
class has_chunks
25+
{
26+
private:
27+
typedef char YesType[1];
28+
typedef char NoType[2];
29+
30+
template <typename C> static YesType& test(decltype(&C::chunk_shape));
31+
template <typename C> static NoType& test(...);
32+
33+
public:
34+
enum { value = sizeof(test<T>(0)) == sizeof(YesType) };
35+
};
36+
2237
/*****************
2338
* xchunked_view *
2439
*****************/
@@ -48,7 +63,10 @@ namespace xt
4863
xchunked_view(OE&& e, S&& chunk_shape);
4964

5065
template <class OE>
51-
xchunked_view<E>& operator=(const OE& e);
66+
typename std::enable_if<!has_chunks<OE>::value, xchunked_view<E>&>::type operator=(const OE& e);
67+
68+
template <class OE>
69+
typename std::enable_if<has_chunks<OE>::value, xchunked_view<E>&>::type operator=(const OE& e);
5270

5371
size_type dimension() const noexcept;
5472
const shape_type& shape() const noexcept;
@@ -114,7 +132,7 @@ namespace xt
114132

115133
template <class E>
116134
template <class OE>
117-
xchunked_view<E>& xchunked_view<E>::operator=(const OE& e)
135+
typename std::enable_if<!has_chunks<OE>::value, xchunked_view<E>&>::type xchunked_view<E>::operator=(const OE& e)
118136
{
119137
for (auto it = chunk_begin(); it != chunk_end(); it++)
120138
{
@@ -124,6 +142,30 @@ namespace xt
124142
return *this;
125143
}
126144

145+
template <class E>
146+
template <class OE>
147+
typename std::enable_if<has_chunks<OE>::value, xchunked_view<E>&>::type xchunked_view<E>::operator=(const OE& e)
148+
{
149+
for (auto it1 = chunk_begin(), it2 = e.chunks().begin(); it1 != chunk_end(); it1++, it2++)
150+
{
151+
auto el1 = *it1;
152+
auto el2 = *it2;
153+
auto lhs_shape = el1.shape();
154+
if (lhs_shape != el2.shape())
155+
{
156+
xstrided_slice_vector esv(el2.dimension()); // element slice in edge chunk
157+
std::transform(lhs_shape.begin(), lhs_shape.end(), esv.begin(),
158+
[](auto size) { return range(0, size); });
159+
noalias(el1) = strided_view(el2, esv);
160+
}
161+
else
162+
{
163+
noalias(el1) = el2;
164+
}
165+
}
166+
return *this;
167+
}
168+
127169
template <class E>
128170
inline auto xchunked_view<E>::dimension() const noexcept -> size_type
129171
{

0 commit comments

Comments
 (0)