@@ -133,6 +133,7 @@ struct NoOpIndexer
133133 }
134134};
135135
136+ /* @brief Indexer with shape and strides arrays of same size are packed */
136137struct StridedIndexer
137138{
138139 StridedIndexer (int _nd,
@@ -143,24 +144,76 @@ struct StridedIndexer
143144 {
144145 }
145146
147+ size_t operator ()(py::ssize_t gid) const
148+ {
149+ return compute_offset (gid);
150+ }
151+
146152 size_t operator ()(size_t gid) const
153+ {
154+ return compute_offset (static_cast <py::ssize_t >(gid));
155+ }
156+
157+ private:
158+ int nd;
159+ py::ssize_t starting_offset;
160+ py::ssize_t const *shape_strides;
161+
162+ size_t compute_offset (py::ssize_t gid) const
147163 {
148164 using dpctl::tensor::strides::CIndexer_vector;
149165
150166 CIndexer_vector _ind (nd);
151167 py::ssize_t relative_offset (0 );
152168 _ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
153- static_cast <py:: ssize_t >( gid) ,
169+ gid,
154170 shape_strides, // shape ptr
155171 shape_strides + nd, // strides ptr
156172 relative_offset);
157173 return starting_offset + relative_offset;
158174 }
175+ };
176+
177+ /* @brief Indexer with shape, strides provided separately */
178+ struct UnpackedStridedIndexer
179+ {
180+ UnpackedStridedIndexer (int _nd,
181+ py::ssize_t _offset,
182+ py::ssize_t const *_shape,
183+ py::ssize_t const *_strides)
184+ : nd(_nd), starting_offset(_offset), shape(_shape), strides(_strides)
185+ {
186+ }
187+
188+ size_t operator ()(py::ssize_t gid) const
189+ {
190+ return compute_offset (gid);
191+ }
192+
193+ size_t operator ()(size_t gid) const
194+ {
195+ return compute_offset (static_cast <py::ssize_t >(gid));
196+ }
159197
160198private:
161199 int nd;
162200 py::ssize_t starting_offset;
163- py::ssize_t const *shape_strides;
201+ py::ssize_t const *shape;
202+ py::ssize_t const *strides;
203+
204+ size_t compute_offset (py::ssize_t gid) const
205+ {
206+ using dpctl::tensor::strides::CIndexer_vector;
207+
208+ CIndexer_vector _ind (nd);
209+ py::ssize_t relative_offset (0 );
210+ _ind.get_displacement <const py::ssize_t *, const py::ssize_t *>(
211+ gid,
212+ shape, // shape ptr
213+ strides, // strides ptr
214+ relative_offset);
215+ return starting_offset + relative_offset;
216+ }
164217};
165218
166219struct Strided1DIndexer
@@ -206,7 +259,8 @@ struct Strided1DCyclicIndexer
206259template <typename displacementT> struct TwoOffsets
207260{
208261 TwoOffsets () : first_offset(0 ), second_offset(0 ) {}
209- TwoOffsets (displacementT first_offset_, displacementT second_offset_)
262+ TwoOffsets (const displacementT &first_offset_,
263+ const displacementT &second_offset_)
210264 : first_offset(first_offset_), second_offset(second_offset_)
211265 {
212266 }
@@ -238,6 +292,22 @@ struct TwoOffsets_StridedIndexer
238292 }
239293
240294 TwoOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
295+ {
296+ return compute_offsets (gid);
297+ }
298+
299+ TwoOffsets<py::ssize_t > operator ()(size_t gid) const
300+ {
301+ return compute_offsets (static_cast <py::ssize_t >(gid));
302+ }
303+
304+ private:
305+ int nd;
306+ py::ssize_t starting_first_offset;
307+ py::ssize_t starting_second_offset;
308+ py::ssize_t const *shape_strides;
309+
310+ TwoOffsets<py::ssize_t > compute_offsets (py::ssize_t gid) const
241311 {
242312 using dpctl::tensor::strides::CIndexer_vector;
243313
@@ -254,12 +324,6 @@ struct TwoOffsets_StridedIndexer
254324 starting_first_offset + relative_first_offset,
255325 starting_second_offset + relative_second_offset);
256326 }
257-
258- private:
259- int nd;
260- py::ssize_t starting_first_offset;
261- py::ssize_t starting_second_offset;
262- py::ssize_t const *shape_strides;
263327};
264328
265329struct TwoZeroOffsets_Indexer
@@ -272,12 +336,33 @@ struct TwoZeroOffsets_Indexer
272336 }
273337};
274338
339+ template <typename FirstIndexerT, typename SecondIndexerT>
340+ struct TwoOffsets_CombinedIndexer
341+ {
342+ private:
343+ FirstIndexerT first_indexer_;
344+ SecondIndexerT second_indexer_;
345+
346+ public:
347+ TwoOffsets_CombinedIndexer (const FirstIndexerT &first_indexer,
348+ const SecondIndexerT &second_indexer)
349+ : first_indexer_(first_indexer), second_indexer_(second_indexer)
350+ {
351+ }
352+
353+ TwoOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
354+ {
355+ return TwoOffsets<py::ssize_t >(first_indexer_ (gid),
356+ second_indexer_ (gid));
357+ }
358+ };
359+
275360template <typename displacementT> struct ThreeOffsets
276361{
277362 ThreeOffsets () : first_offset(0 ), second_offset(0 ), third_offset(0 ) {}
278- ThreeOffsets (displacementT first_offset_,
279- displacementT second_offset_,
280- displacementT third_offset_)
363+ ThreeOffsets (const displacementT & first_offset_,
364+ const displacementT & second_offset_,
365+ const displacementT & third_offset_)
281366 : first_offset(first_offset_), second_offset(second_offset_),
282367 third_offset (third_offset_)
283368 {
@@ -317,6 +402,23 @@ struct ThreeOffsets_StridedIndexer
317402 }
318403
319404 ThreeOffsets<py::ssize_t > operator ()(py::ssize_t gid) const
405+ {
406+ return compute_offsets (gid);
407+ }
408+
409+ ThreeOffsets<py::ssize_t > operator ()(size_t gid) const
410+ {
411+ return compute_offsets (static_cast <py::ssize_t >(gid));
412+ }
413+
414+ private:
415+ int nd;
416+ py::ssize_t starting_first_offset;
417+ py::ssize_t starting_second_offset;
418+ py::ssize_t starting_third_offset;
419+ py::ssize_t const *shape_strides;
420+
421+ ThreeOffsets<py::ssize_t > compute_offsets (py::ssize_t gid) const
320422 {
321423 using dpctl::tensor::strides::CIndexer_vector;
322424
@@ -337,13 +439,6 @@ struct ThreeOffsets_StridedIndexer
337439 starting_second_offset + relative_second_offset,
338440 starting_third_offset + relative_third_offset);
339441 }
340-
341- private:
342- int nd;
343- py::ssize_t starting_first_offset;
344- py::ssize_t starting_second_offset;
345- py::ssize_t starting_third_offset;
346- py::ssize_t const *shape_strides;
347442};
348443
349444struct ThreeZeroOffsets_Indexer
0 commit comments