@@ -807,6 +807,53 @@ namespace xt
807807 );
808808 }
809809
810+ namespace detail
811+ {
812+ template <typename S>
813+ struct rebind_shape ;
814+
815+ template <std::size_t ... X>
816+ struct rebind_shape <xt::fixed_shape<X...>>
817+ {
818+ using type = xt::fixed_shape<X...>;
819+ };
820+
821+ template <class S >
822+ struct rebind_shape
823+ {
824+ using type = rebind_container_t <size_t , S>;
825+ };
826+
827+ template <
828+ class S ,
829+ std::enable_if_t <std::is_signed<get_value_type_t <typename std::decay<S>::type>>::value, bool > = true >
830+ inline void recalculate_shape_impl (S& shape, size_t size)
831+ {
832+ using value_type = get_value_type_t <typename std::decay_t <S>>;
833+ XTENSOR_ASSERT (std::count (shape.cbegin (), shape.cend (), -1 ) <= 1 );
834+ auto iter = std::find (shape.begin (), shape.end (), -1 );
835+ if (iter != std::end (shape))
836+ {
837+ const auto total = std::accumulate (shape.cbegin (), shape.cend (), -1 , std::multiplies<int >{});
838+ const auto missing_dimension = size / total;
839+ (*iter) = static_cast <value_type>(missing_dimension);
840+ }
841+ }
842+
843+ template <
844+ class S ,
845+ std::enable_if_t <!std::is_signed<get_value_type_t <typename std::decay<S>::type>>::value, bool > = true >
846+ inline void recalculate_shape_impl (S&, size_t )
847+ {
848+ }
849+
850+ template <class S >
851+ inline auto recalculate_shape (S&& shape, size_t size)
852+ {
853+ return recalculate_shape_impl (shape, size);
854+ }
855+ }
856+
810857 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E , class S >
811858 inline auto reshape_view (E&& e, S&& shape)
812859 {
@@ -815,18 +862,26 @@ namespace xt
815862 " traversal has to be row or column major"
816863 );
817864
818- using shape_type = std::decay_t <S>;
819- get_strides_t <shape_type> strides;
865+ using shape_type = std::decay_t <decltype (shape)>;
866+ using unsigned_shape_type = typename detail::rebind_shape<shape_type>::type;
867+ get_strides_t <unsigned_shape_type> strides;
820868
869+ detail::recalculate_shape (shape, e.size ());
821870 xt::resize_container (strides, shape.size ());
822871 compute_strides (shape, L, strides);
823872 constexpr auto computed_layout = std::decay_t <E>::static_layout == L ? L : layout_type::dynamic;
824873 using view_type = xstrided_view<
825874 xclosure_t <E>,
826- shape_type ,
875+ unsigned_shape_type ,
827876 computed_layout,
828877 detail::flat_adaptor_getter<xclosure_t <E>, L>>;
829- return view_type (std::forward<E>(e), std::forward<S>(shape), std::move (strides), 0 , e.layout ());
878+ return view_type (
879+ std::forward<E>(e),
880+ xtl::forward_sequence<unsigned_shape_type, S>(shape),
881+ std::move (strides),
882+ 0 ,
883+ e.layout ()
884+ );
830885 }
831886
832887 /* *
@@ -858,7 +913,7 @@ namespace xt
858913 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E , class I , std::size_t N>
859914 inline auto reshape_view (E&& e, const I (&shape)[N])
860915 {
861- using shape_type = std::array<std:: size_t , N>;
916+ using shape_type = std::array<I , N>;
862917 return reshape_view<L>(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype (shape)>(shape));
863918 }
864919}
0 commit comments