Skip to content

Commit 585f2e5

Browse files
Extend ._tensor_impl with linear sequences functions (#2782)
This PR is the final one in the series of extending `_tensor_impl` extension It extends `_tensor_impl` in `dpctl_ext.tensor` with linear sequence functions (`_linspace_step and _linspace_affine`) Also this PR significantly expands Python API of `dpctl_ext.tensor` by adding all missing functions from `dpctl_ext.tensor._ctors` and `dpctl_ext.tensor._manipulation_functions` `_tensor_impl`: 45 / 45 functions Python API dpctl_ext.tensor: 70 / 233 functions
1 parent 0edd3b1 commit 585f2e5

40 files changed

Lines changed: 3194 additions & 334 deletions

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ set(_tensor_impl_sources
5151
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
5252
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
5353
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_roll.cpp
54-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
54+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
5555
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/integer_advanced_indexing.cpp
5656
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_advanced_indexing.cpp
5757
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
@@ -93,7 +93,7 @@ endif()
9393
set(_no_fast_math_sources
9494
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
9595
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
96-
# ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
96+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
9797
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
9898
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp
9999
)

dpctl_ext/tensor/__init__.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,21 @@
3636
to_numpy,
3737
)
3838
from ._ctors import (
39+
arange,
40+
asarray,
41+
empty,
42+
empty_like,
3943
eye,
4044
full,
45+
full_like,
46+
linspace,
47+
meshgrid,
48+
ones,
49+
ones_like,
4150
tril,
4251
triu,
52+
zeros,
53+
zeros_like,
4354
)
4455
from ._indexing_functions import (
4556
extract,
@@ -51,38 +62,73 @@
5162
take_along_axis,
5263
)
5364
from ._manipulation_functions import (
65+
broadcast_arrays,
66+
broadcast_to,
67+
concat,
68+
expand_dims,
69+
flip,
70+
moveaxis,
71+
permute_dims,
5472
repeat,
5573
roll,
74+
squeeze,
75+
stack,
76+
swapaxes,
77+
tile,
78+
unstack,
5679
)
5780
from ._reshape import reshape
5881
from ._search_functions import where
5982
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
6083

6184
__all__ = [
85+
"arange",
86+
"asarray",
6287
"asnumpy",
6388
"astype",
89+
"broadcast_arrays",
90+
"broadcast_to",
6491
"can_cast",
92+
"concat",
6593
"copy",
6694
"clip",
95+
"empty",
96+
"empty_like",
6797
"extract",
98+
"expand_dims",
6899
"eye",
69100
"finfo",
101+
"flip",
70102
"from_numpy",
71103
"full",
104+
"full_like",
72105
"iinfo",
73106
"isdtype",
107+
"linspace",
108+
"meshgrid",
109+
"moveaxis",
110+
"permute_dims",
74111
"nonzero",
112+
"ones",
113+
"ones_like",
75114
"place",
76115
"put",
77116
"put_along_axis",
78117
"repeat",
79118
"reshape",
80119
"result_type",
81120
"roll",
121+
"squeeze",
122+
"stack",
123+
"swapaxes",
82124
"take",
83125
"take_along_axis",
126+
"tile",
84127
"to_numpy",
85128
"tril",
86129
"triu",
130+
"unstack",
87131
"where",
132+
"zeros",
133+
"zeros_like",
88134
]

dpctl_ext/tensor/_clip.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -163,20 +163,20 @@ def _clip_none(x, val, out, order, _binary_fn):
163163

164164
if ti._array_overlap(x, out):
165165
if not ti._same_logical_tensors(x, out):
166-
out = dpt.empty_like(out)
166+
out = dpt_ext.empty_like(out)
167167

168168
if isinstance(val, dpt.usm_ndarray):
169169
if (
170170
ti._array_overlap(val, out)
171171
and not ti._same_logical_tensors(val, out)
172172
and val_dtype == res_dt
173173
):
174-
out = dpt.empty_like(out)
174+
out = dpt_ext.empty_like(out)
175175

176176
if isinstance(val, dpt.usm_ndarray):
177177
val_ary = val
178178
else:
179-
val_ary = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
179+
val_ary = dpt_ext.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
180180

181181
if order == "A":
182182
order = (
@@ -197,17 +197,17 @@ def _clip_none(x, val, out, order, _binary_fn):
197197
x, val_ary, res_dt, res_shape, res_usm_type, exec_q
198198
)
199199
else:
200-
out = dpt.empty(
200+
out = dpt_ext.empty(
201201
res_shape,
202202
dtype=res_dt,
203203
usm_type=res_usm_type,
204204
sycl_queue=exec_q,
205205
order=order,
206206
)
207207
if x_shape != res_shape:
208-
x = dpt.broadcast_to(x, res_shape)
208+
x = dpt_ext.broadcast_to(x, res_shape)
209209
if val_ary.shape != res_shape:
210-
val_ary = dpt.broadcast_to(val_ary, res_shape)
210+
val_ary = dpt_ext.broadcast_to(val_ary, res_shape)
211211
_manager = SequentialOrderManager[exec_q]
212212
dep_evs = _manager.submitted_events
213213
ht_binary_ev, binary_ev = _binary_fn(
@@ -229,7 +229,7 @@ def _clip_none(x, val, out, order, _binary_fn):
229229
if order == "K":
230230
buf = _empty_like_orderK(val_ary, res_dt)
231231
else:
232-
buf = dpt.empty_like(val_ary, dtype=res_dt, order=order)
232+
buf = dpt_ext.empty_like(val_ary, dtype=res_dt, order=order)
233233
_manager = SequentialOrderManager[exec_q]
234234
dep_evs = _manager.submitted_events
235235
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -242,7 +242,7 @@ def _clip_none(x, val, out, order, _binary_fn):
242242
x, buf, res_dt, res_shape, res_usm_type, exec_q
243243
)
244244
else:
245-
out = dpt.empty(
245+
out = dpt_ext.empty(
246246
res_shape,
247247
dtype=res_dt,
248248
usm_type=res_usm_type,
@@ -251,8 +251,8 @@ def _clip_none(x, val, out, order, _binary_fn):
251251
)
252252

253253
if x_shape != res_shape:
254-
x = dpt.broadcast_to(x, res_shape)
255-
buf = dpt.broadcast_to(buf, res_shape)
254+
x = dpt_ext.broadcast_to(x, res_shape)
255+
buf = dpt_ext.broadcast_to(buf, res_shape)
256256
ht_binary_ev, binary_ev = _binary_fn(
257257
src1=x,
258258
src2=buf,
@@ -353,14 +353,14 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
353353

354354
if ti._array_overlap(x, out):
355355
if not ti._same_logical_tensors(x, out):
356-
out = dpt.empty_like(out)
356+
out = dpt_ext.empty_like(out)
357357
else:
358358
return out
359359
else:
360360
if order == "K":
361361
out = _empty_like_orderK(x, x.dtype)
362362
else:
363-
out = dpt.empty_like(x, order=order)
363+
out = dpt_ext.empty_like(x, order=order)
364364

365365
_manager = SequentialOrderManager[exec_q]
366366
dep_evs = _manager.submitted_events
@@ -519,32 +519,32 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
519519

520520
if ti._array_overlap(x, out):
521521
if not ti._same_logical_tensors(x, out):
522-
out = dpt.empty_like(out)
522+
out = dpt_ext.empty_like(out)
523523

524524
if isinstance(min, dpt.usm_ndarray):
525525
if (
526526
ti._array_overlap(min, out)
527527
and not ti._same_logical_tensors(min, out)
528528
and buf1_dt is None
529529
):
530-
out = dpt.empty_like(out)
530+
out = dpt_ext.empty_like(out)
531531

532532
if isinstance(max, dpt.usm_ndarray):
533533
if (
534534
ti._array_overlap(max, out)
535535
and not ti._same_logical_tensors(max, out)
536536
and buf2_dt is None
537537
):
538-
out = dpt.empty_like(out)
538+
out = dpt_ext.empty_like(out)
539539

540540
if isinstance(min, dpt.usm_ndarray):
541541
a_min = min
542542
else:
543-
a_min = dpt.asarray(min, dtype=min_dtype, sycl_queue=exec_q)
543+
a_min = dpt_ext.asarray(min, dtype=min_dtype, sycl_queue=exec_q)
544544
if isinstance(max, dpt.usm_ndarray):
545545
a_max = max
546546
else:
547-
a_max = dpt.asarray(max, dtype=max_dtype, sycl_queue=exec_q)
547+
a_max = dpt_ext.asarray(max, dtype=max_dtype, sycl_queue=exec_q)
548548

549549
if order == "A":
550550
order = (
@@ -572,19 +572,19 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
572572
exec_q,
573573
)
574574
else:
575-
out = dpt.empty(
575+
out = dpt_ext.empty(
576576
res_shape,
577577
dtype=res_dt,
578578
usm_type=res_usm_type,
579579
sycl_queue=exec_q,
580580
order=order,
581581
)
582582
if x_shape != res_shape:
583-
x = dpt.broadcast_to(x, res_shape)
583+
x = dpt_ext.broadcast_to(x, res_shape)
584584
if a_min.shape != res_shape:
585-
a_min = dpt.broadcast_to(a_min, res_shape)
585+
a_min = dpt_ext.broadcast_to(a_min, res_shape)
586586
if a_max.shape != res_shape:
587-
a_max = dpt.broadcast_to(a_max, res_shape)
587+
a_max = dpt_ext.broadcast_to(a_max, res_shape)
588588
_manager = SequentialOrderManager[exec_q]
589589
dep_ev = _manager.submitted_events
590590
ht_binary_ev, binary_ev = ti._clip(
@@ -612,7 +612,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
612612
if order == "K":
613613
buf2 = _empty_like_orderK(a_max, buf2_dt)
614614
else:
615-
buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order)
615+
buf2 = dpt_ext.empty_like(a_max, dtype=buf2_dt, order=order)
616616
_manager = SequentialOrderManager[exec_q]
617617
dep_ev = _manager.submitted_events
618618
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -631,18 +631,18 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
631631
exec_q,
632632
)
633633
else:
634-
out = dpt.empty(
634+
out = dpt_ext.empty(
635635
res_shape,
636636
dtype=res_dt,
637637
usm_type=res_usm_type,
638638
sycl_queue=exec_q,
639639
order=order,
640640
)
641641

642-
x = dpt.broadcast_to(x, res_shape)
642+
x = dpt_ext.broadcast_to(x, res_shape)
643643
if a_min.shape != res_shape:
644-
a_min = dpt.broadcast_to(a_min, res_shape)
645-
buf2 = dpt.broadcast_to(buf2, res_shape)
644+
a_min = dpt_ext.broadcast_to(a_min, res_shape)
645+
buf2 = dpt_ext.broadcast_to(buf2, res_shape)
646646
ht_binary_ev, binary_ev = ti._clip(
647647
src=x,
648648
min=a_min,
@@ -668,7 +668,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
668668
if order == "K":
669669
buf1 = _empty_like_orderK(a_min, buf1_dt)
670670
else:
671-
buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order)
671+
buf1 = dpt_ext.empty_like(a_min, dtype=buf1_dt, order=order)
672672
_manager = SequentialOrderManager[exec_q]
673673
dep_ev = _manager.submitted_events
674674
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -687,18 +687,18 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
687687
exec_q,
688688
)
689689
else:
690-
out = dpt.empty(
690+
out = dpt_ext.empty(
691691
res_shape,
692692
dtype=res_dt,
693693
usm_type=res_usm_type,
694694
sycl_queue=exec_q,
695695
order=order,
696696
)
697697

698-
x = dpt.broadcast_to(x, res_shape)
699-
buf1 = dpt.broadcast_to(buf1, res_shape)
698+
x = dpt_ext.broadcast_to(x, res_shape)
699+
buf1 = dpt_ext.broadcast_to(buf1, res_shape)
700700
if a_max.shape != res_shape:
701-
a_max = dpt.broadcast_to(a_max, res_shape)
701+
a_max = dpt_ext.broadcast_to(a_max, res_shape)
702702
ht_binary_ev, binary_ev = ti._clip(
703703
src=x,
704704
min=buf1,
@@ -736,7 +736,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
736736
if order == "K":
737737
buf1 = _empty_like_orderK(a_min, buf1_dt)
738738
else:
739-
buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order)
739+
buf1 = dpt_ext.empty_like(a_min, dtype=buf1_dt, order=order)
740740

741741
_manager = SequentialOrderManager[exec_q]
742742
dep_evs = _manager.submitted_events
@@ -747,7 +747,7 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
747747
if order == "K":
748748
buf2 = _empty_like_orderK(a_max, buf2_dt)
749749
else:
750-
buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order)
750+
buf2 = dpt_ext.empty_like(a_max, dtype=buf2_dt, order=order)
751751
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
752752
src=a_max, dst=buf2, sycl_queue=exec_q, depends=dep_evs
753753
)
@@ -758,17 +758,17 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
758758
x, buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
759759
)
760760
else:
761-
out = dpt.empty(
761+
out = dpt_ext.empty(
762762
res_shape,
763763
dtype=res_dt,
764764
usm_type=res_usm_type,
765765
sycl_queue=exec_q,
766766
order=order,
767767
)
768768

769-
x = dpt.broadcast_to(x, res_shape)
770-
buf1 = dpt.broadcast_to(buf1, res_shape)
771-
buf2 = dpt.broadcast_to(buf2, res_shape)
769+
x = dpt_ext.broadcast_to(x, res_shape)
770+
buf1 = dpt_ext.broadcast_to(buf1, res_shape)
771+
buf2 = dpt_ext.broadcast_to(buf2, res_shape)
772772
ht_, clip_ev = ti._clip(
773773
src=x,
774774
min=buf1,

0 commit comments

Comments
 (0)