Skip to content

Commit 69eaac5

Browse files
authored
Merge pull request #2746 from spectre-ns/reshape
[WIP] Make reshape_view accept -1 as a wildcard dimension
2 parents 5f49f64 + aaa819e commit 69eaac5

2 files changed

Lines changed: 77 additions & 15 deletions

File tree

include/xtensor/xstrided_view.hpp

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,53 @@ namespace xt
807807
);
808808
}
809809

810+
namespace detail
811+
{
812+
template <typename S>
813+
struct rebind_shape;
814+
815+
template <std::size_t... X>
816+
struct rebind_shape<xt::fixed_shape<X...>>
817+
{
818+
using type = xt::fixed_shape<X...>;
819+
};
820+
821+
template <class S>
822+
struct rebind_shape
823+
{
824+
using type = rebind_container_t<size_t, S>;
825+
};
826+
827+
template <
828+
class S,
829+
std::enable_if_t<std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, bool> = true>
830+
inline void recalculate_shape_impl(S& shape, size_t size)
831+
{
832+
using value_type = get_value_type_t<typename std::decay_t<S>>;
833+
XTENSOR_ASSERT(std::count(shape.cbegin(), shape.cend(), -1) <= 1);
834+
auto iter = std::find(shape.begin(), shape.end(), -1);
835+
if (iter != std::end(shape))
836+
{
837+
const auto total = std::accumulate(shape.cbegin(), shape.cend(), -1, std::multiplies<int>{});
838+
const auto missing_dimension = size / total;
839+
(*iter) = static_cast<value_type>(missing_dimension);
840+
}
841+
}
842+
843+
template <
844+
class S,
845+
std::enable_if_t<!std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, bool> = true>
846+
inline void recalculate_shape_impl(S&, size_t)
847+
{
848+
}
849+
850+
template <class S>
851+
inline auto recalculate_shape(S&& shape, size_t size)
852+
{
853+
return recalculate_shape_impl(shape, size);
854+
}
855+
}
856+
810857
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class S>
811858
inline auto reshape_view(E&& e, S&& shape)
812859
{
@@ -815,18 +862,26 @@ namespace xt
815862
"traversal has to be row or column major"
816863
);
817864

818-
using shape_type = std::decay_t<S>;
819-
get_strides_t<shape_type> strides;
865+
using shape_type = std::decay_t<decltype(shape)>;
866+
using unsigned_shape_type = typename detail::rebind_shape<shape_type>::type;
867+
get_strides_t<unsigned_shape_type> strides;
820868

869+
detail::recalculate_shape(shape, e.size());
821870
xt::resize_container(strides, shape.size());
822871
compute_strides(shape, L, strides);
823872
constexpr auto computed_layout = std::decay_t<E>::static_layout == L ? L : layout_type::dynamic;
824873
using view_type = xstrided_view<
825874
xclosure_t<E>,
826-
shape_type,
875+
unsigned_shape_type,
827876
computed_layout,
828877
detail::flat_adaptor_getter<xclosure_t<E>, L>>;
829-
return view_type(std::forward<E>(e), std::forward<S>(shape), std::move(strides), 0, e.layout());
878+
return view_type(
879+
std::forward<E>(e),
880+
xtl::forward_sequence<unsigned_shape_type, S>(shape),
881+
std::move(strides),
882+
0,
883+
e.layout()
884+
);
830885
}
831886

832887
/**
@@ -858,7 +913,7 @@ namespace xt
858913
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class I, std::size_t N>
859914
inline auto reshape_view(E&& e, const I (&shape)[N])
860915
{
861-
using shape_type = std::array<std::size_t, N>;
916+
using shape_type = std::array<I, N>;
862917
return reshape_view<L>(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(shape)>(shape));
863918
}
864919
}

test/test_xstrided_view.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -696,24 +696,31 @@ namespace xt
696696
EXPECT_EQ(av, e);
697697
EXPECT_EQ(av, a);
698698

699-
bool truthy;
700-
truthy = std::is_same<
701-
typename decltype(xv)::temporary_type,
702-
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>();
703-
EXPECT_TRUE(truthy);
704-
705-
truthy = std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>(
699+
static_assert(
700+
std::is_same<
701+
typename decltype(xv)::temporary_type,
702+
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>::value,
703+
"Container types do not match"
704+
);
705+
static_assert(
706+
std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>::value,
707+
"Container types do not match"
708+
);
709+
static_assert(
710+
std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value,
711+
"Shape types do not match"
706712
);
707-
EXPECT_TRUE(truthy);
708-
truthy = std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value;
709-
EXPECT_TRUE(truthy);
710713

711714
xarray<int> xa = {{1, 2, 3}, {4, 5, 6}};
712715
std::vector<std::size_t> new_shape = {3, 2};
713716
auto xrv = reshape_view(xa, new_shape);
714717

715718
xarray<int> xres = {{1, 2}, {3, 4}, {5, 6}};
716719
EXPECT_EQ(xrv, xres);
720+
721+
auto nv = xt::reshape_view<XTENSOR_DEFAULT_LAYOUT>(a, {-1, 3});
722+
std::vector<size_t> expected_shape({3, 3});
723+
EXPECT_TRUE(std::equal(nv.shape().begin(), nv.shape().end(), expected_shape.begin()));
717724
}
718725

719726
TEST(xstrided_view, reshape_view_assign)

0 commit comments

Comments
 (0)