4040import numpy
4141
4242
43- try :
44- """
45- Detect DPCtl availability to use data container
46- """
47- import dpctl .tensor as dpctl
48-
49- config .__DPNP_DPCTL_AVAILABLE__ = True
50-
51- except ImportError :
52- """
53- No DPCtl data container available
54- """
55- config .__DPNP_DPCTL_AVAILABLE__ = False
43+ if config .__DPNP_OUTPUT_DPCTL__ :
44+ try :
45+ """
46+ Detect DPCtl availability to use data container
47+ """
48+ import dpctl .tensor as dpctl
5649
57- # config.__DPNP_DPCTL_AVAILABLE__ = False
50+ except ImportError :
51+ """
52+ No DPCtl data container available
53+ """
54+ config .__DPNP_OUTPUT_DPCTL__ = 0
5855
5956
6057__all__ = [
@@ -67,11 +64,38 @@ def create_output_container(shape, type):
6764 """ Create NumPy ndarray """
6865 # TODO need to use "buffer=" parameter to use SYCL aware memory
6966 result = numpy .ndarray (shape , dtype = type )
70- elif config .__DPNP_DPCTL_AVAILABLE__ :
67+ elif config .__DPNP_OUTPUT_DPCTL__ :
7168 """ Create DPCTL array """
72- result = dpctl .usm_ndarray (shape , dtype = numpy .dtype (type ).name )
69+ if config .__DPNP_OUTPUT_DPCTL_DEFAULT_SHARED__ :
70+ """
71+ From DPCtrl documentation:
72+ 'buffer can be strings ('device'|'shared'|'host' to allocate new memory)'
73+ """
74+ result = dpctl .usm_ndarray (shape , dtype = numpy .dtype (type ).name , buffer = 'shared' )
75+ else :
76+ """
77+ Can't pass 'None' as buffer= parameter to allow DPCtrl uses it's default
78+ """
79+ result = dpctl .usm_ndarray (shape , dtype = numpy .dtype (type ).name )
7380 else :
7481 """ Create DPNP array """
7582 result = dparray (shape , dtype = type )
7683
7784 return result
85+
86+
87+ def container_copy (dst_obj , src_obj , dst_idx = 0 ):
88+ """
89+ Copy values to `dst` by iterating element by element in `input_obj`
90+ """
91+
92+ for elem_value in src_obj :
93+ if isinstance (elem_value , (list , tuple )):
94+ dst_idx = container_copy (dst_obj , elem_value , dst_idx )
95+ elif issubclass (type (elem_value ), (numpy .ndarray , dparray )):
96+ dst_idx = container_copy (dst_obj , elem_value , dst_idx )
97+ else :
98+ dst_obj .flat [dst_idx ] = elem_value
99+ dst_idx += 1
100+
101+ return dst_idx
0 commit comments