@@ -46,51 +46,46 @@ using namespace dpctl::tensor::offset_utils;
4646
4747template <typename srcT, typename dstT, typename IndexerT>
4848class copy_cast_generic_kernel ;
49+
4950template <typename srcT, typename dstT, typename IndexerT>
5051class copy_cast_from_host_kernel ;
51- // template <typename srcT, typename dstT, typename IndexerT>
52- // class copy_cast_spec_kernel;
52+
5353template <typename Ty, typename SrcIndexerT, typename DstIndexerT>
5454class copy_for_reshape_generic_kernel ;
5555
56- template <typename srcT , typename dstT > class Caster
56+ template <typename srcTy , typename dstTy > class Caster
5757{
5858public:
5959 Caster () = default ;
60- void operator ()(const char *src,
61- std::ptrdiff_t src_offset,
62- char *dst,
63- std::ptrdiff_t dst_offset) const
60+ dstTy operator ()(const srcTy &src) const
6461 {
6562 using dpctl::tensor::type_utils::convert_impl;
66-
67- const srcT *src_ = reinterpret_cast <const srcT *>(src) + src_offset;
68- dstT *dst_ = reinterpret_cast <dstT *>(dst) + dst_offset;
69- *dst_ = convert_impl<dstT, srcT>(*src_);
63+ return convert_impl<dstTy, srcTy>(src);
7064 }
7165};
7266
73- template <typename CastFnT, typename IndexerT> class GenericCopyFunctor
67+ template <typename srcT, typename dstT, typename CastFnT, typename IndexerT>
68+ class GenericCopyFunctor
7469{
7570private:
76- const char *src_ = nullptr ;
77- char *dst_ = nullptr ;
71+ const srcT *src_ = nullptr ;
72+ dstT *dst_ = nullptr ;
7873 IndexerT indexer_;
7974
8075public:
81- GenericCopyFunctor (const char *src_cp, char *dst_cp , IndexerT indexer)
82- : src_(src_cp ), dst_(dst_cp ), indexer_(indexer)
76+ GenericCopyFunctor (const srcT *src_p, dstT *dst_p , IndexerT indexer)
77+ : src_(src_p ), dst_(dst_p ), indexer_(indexer)
8378 {
8479 }
8580
8681 void operator ()(sycl::id<1 > wiid) const
8782 {
88- auto offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
89- py::ssize_t src_offset = offsets.get_first_offset ();
90- py::ssize_t dst_offset = offsets.get_second_offset ();
83+ const auto & offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
84+ const py::ssize_t & src_offset = offsets.get_first_offset ();
85+ const py::ssize_t & dst_offset = offsets.get_second_offset ();
9186
9287 CastFnT fn{};
93- fn (src_, src_offset, dst_, dst_offset );
88+ dst_[dst_offset] = fn (src_[ src_offset] );
9489 }
9590};
9691
@@ -168,12 +163,15 @@ copy_and_cast_generic_impl(sycl::queue q,
168163
169164 TwoOffsets_StridedIndexer indexer{nd, src_offset, dst_offset,
170165 shape_and_strides};
166+ const srcTy *src_tp = reinterpret_cast <const srcTy *>(src_p);
167+ dstTy *dst_tp = reinterpret_cast <dstTy *>(dst_p);
171168
172169 cgh.parallel_for <class copy_cast_generic_kernel <
173170 srcTy, dstTy, TwoOffsets_StridedIndexer>>(
174171 sycl::range<1 >(nelems),
175- GenericCopyFunctor<Caster<srcTy, dstTy>, TwoOffsets_StridedIndexer>(
176- src_p, dst_p, indexer));
172+ GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>,
173+ TwoOffsets_StridedIndexer>(src_tp, dst_tp,
174+ indexer));
177175 });
178176
179177 return copy_and_cast_ev;
@@ -276,13 +274,15 @@ copy_and_cast_nd_specialized_impl(sycl::queue q,
276274 using IndexerT = TwoOffsets_FixedDimStridedIndexer<nd>;
277275 IndexerT indexer{shape, src_strides, dst_strides, src_offset,
278276 dst_offset};
277+ const srcTy *src_tp = reinterpret_cast <const srcTy *>(src_p);
278+ dstTy *dst_tp = reinterpret_cast <dstTy *>(dst_p);
279279
280280 cgh.depends_on (depends);
281281 cgh.parallel_for <
282282 class copy_cast_generic_kernel <srcTy, dstTy, IndexerT>>(
283283 sycl::range<1 >(nelems),
284- GenericCopyFunctor<Caster<srcTy, dstTy>, IndexerT>(src_p, dst_p,
285- indexer));
284+ GenericCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, IndexerT>(
285+ src_tp, dst_tp, indexer));
286286 });
287287
288288 return copy_and_cast_ev;
@@ -318,46 +318,33 @@ template <typename fnT, typename D, typename S> struct CopyAndCast2DFactory
318318
319319// ====================== Copying from host to USM
320320
321- template <typename srcT, typename dstT, typename AccessorT>
322- class CasterForAccessor
323- {
324- public:
325- CasterForAccessor () = default ;
326- void operator ()(AccessorT src,
327- std::ptrdiff_t src_offset,
328- char *dst,
329- std::ptrdiff_t dst_offset) const
330- {
331- using dpctl::tensor::type_utils::convert_impl;
332-
333- dstT *dst_ = reinterpret_cast <dstT *>(dst) + dst_offset;
334- *dst_ = convert_impl<dstT, srcT>(src[src_offset]);
335- }
336- };
337-
338- template <typename CastFnT, typename AccessorT, typename IndexerT>
321+ template <typename AccessorT,
322+ typename dstTy,
323+ typename CastFnT,
324+ typename IndexerT>
339325class GenericCopyFromHostFunctor
340326{
341327private:
342328 AccessorT src_acc_;
343- char *dst_ = nullptr ;
329+ dstTy *dst_ = nullptr ;
344330 IndexerT indexer_;
345331
346332public:
347333 GenericCopyFromHostFunctor (AccessorT src_acc,
348- char *dst_cp ,
334+ dstTy *dst_p ,
349335 IndexerT indexer)
350- : src_acc_(src_acc), dst_(dst_cp ), indexer_(indexer)
336+ : src_acc_(src_acc), dst_(dst_p ), indexer_(indexer)
351337 {
352338 }
353339
354340 void operator ()(sycl::id<1 > wiid) const
355341 {
356- auto offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
357- py::ssize_t src_offset = offsets.get_first_offset ();
358- py::ssize_t dst_offset = offsets.get_second_offset ();
342+ const auto &offsets = indexer_ (static_cast <py::ssize_t >(wiid.get (0 )));
343+ const py::ssize_t &src_offset = offsets.get_first_offset ();
344+ const py::ssize_t &dst_offset = offsets.get_second_offset ();
345+
359346 CastFnT fn{};
360- fn (src_acc_, src_offset, dst_, dst_offset );
347+ dst_[dst_offset] = fn (src_acc_[ src_offset] );
361348 }
362349};
363350
@@ -447,13 +434,15 @@ void copy_and_cast_from_host_impl(
447434 nd, src_offset - src_min_nelem_offset, dst_offset,
448435 const_cast <const py::ssize_t *>(shape_and_strides)};
449436
437+ dstTy *dst_tp = reinterpret_cast <dstTy *>(dst_p);
438+
450439 cgh.parallel_for <copy_cast_from_host_kernel<srcTy, dstTy,
451440 TwoOffsets_StridedIndexer>>(
452441 sycl::range<1 >(nelems),
453- GenericCopyFromHostFunctor<
454- CasterForAccessor <srcTy, dstTy, decltype (npy_acc) >,
455- decltype (npy_acc), TwoOffsets_StridedIndexer>(npy_acc, dst_p,
456- indexer));
442+ GenericCopyFromHostFunctor<decltype (npy_acc), dstTy,
443+ Caster <srcTy, dstTy>,
444+ TwoOffsets_StridedIndexer>(
445+ npy_acc, dst_tp, indexer));
457446 });
458447
459448 // perform explicit synchronization. Implicit synchronization would be
0 commit comments