5050from .dpnp_utils import *
5151
5252import dpnp
53+ from dpnp .dpnp_array import dpnp_array
5354
5455import numpy
5556import dpctl .tensor as dpt
@@ -173,7 +174,7 @@ def absolute(x,
173174 -------
174175 y : dpnp.ndarray
175176 An array containing the absolute value of each element in `x`.
176-
177+
177178 Limitations
178179 -----------
179180 Parameters `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
@@ -601,7 +602,7 @@ def divide(x1,
601602 -------
602603 y : dpnp.ndarray
603604 The quotient ``x1/x2``, element-wise.
604-
605+
605606 Limitations
606607 -----------
607608 Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
@@ -1342,7 +1343,7 @@ def power(x1,
13421343 -------
13431344 y : dpnp.ndarray
13441345 The bases in `x1` raised to the exponents in `x2`.
1345-
1346+
13461347 Limitations
13471348 -----------
13481349 Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
@@ -1568,7 +1569,7 @@ def subtract(x1,
15681569 -------
15691570 y : dpnp.ndarray
15701571 The difference of `x1` and `x2`, element-wise.
1571-
1572+
15721573 Limitations
15731574 -----------
15741575 Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
@@ -1590,45 +1591,52 @@ def subtract(x1,
15901591 return _check_nd_call (numpy .subtract , dpnp_subtract , x1 , x2 , out = out , where = where , order = order , dtype = dtype , subok = subok , ** kwargs )
15911592
15921593
1593- def sum (x1 , axis = None , dtype = None , out = None , keepdims = False , initial = None , where = True ):
1594+ def sum (x , / , * , axis = None , dtype = None , keepdims = False , out = None , initial = 0 , where = True ):
15941595 """
15951596 Sum of array elements over a given axis.
15961597
15971598 For full documentation refer to :obj:`numpy.sum`.
15981599
1600+ Returns
1601+ -------
1602+ y : dpnp.ndarray
1603+ an array containing the sums. If the sum was computed over the
1604+ entire array, a zero-dimensional array is returned. The returned
1605+ array has the data type as described in the `dtype` parameter
1606+ of the Python Array API standard for the `sum` function.
1607+
15991608 Limitations
16001609 -----------
1601- Parameter `where`` is unsupported.
1602- Input array data types are limited by DPNP :ref:`Data types`.
1610+ Parameters `x` is supported as either :class:`dpnp.ndarray`
1611+ or :class:`dpctl.tensor.usm_ndarray`.
1612+ Parameters `out`, `initial` and `where` are supported with their default values.
1613+ Otherwise the function will be executed sequentially on CPU.
1614+ Input array data types are limited by supported DPNP :ref:`Data types`.
16031615
16041616 Examples
16051617 --------
16061618 >>> import dpnp as np
16071619 >>> np.sum(np.array([1, 2, 3, 4, 5]))
1608- 15
1609- >>> result = np.sum([[0, 1], [0, 5]], axis=0)
1610- [0, 6]
1620+ array(15)
1621+ >>> np.sum(np.array(5))
1622+ array(5)
1623+ >>> result = np.sum(np.array([[0, 1], [0, 5]]), axis=0)
1624+ array([0, 6])
16111625
16121626 """
16131627
1614- x1_desc = dpnp .get_dpnp_descriptor (x1 , copy_when_nondefault_queue = False )
1615- if x1_desc :
1616- if where is not True :
1617- pass
1618- else :
1619- if dpnp .isscalar (out ):
1620- raise TypeError ("output must be an array" )
1621- out_desc = dpnp .get_dpnp_descriptor (out , copy_when_nondefault_queue = False ) if out is not None else None
1622- result_obj = dpnp_sum (x1_desc , axis , dtype , out_desc , keepdims , initial , where ).get_pyobj ()
1623- result = dpnp .convert_single_elem_array_to_scalar (result_obj , keepdims )
16241628
1625- if x1_desc .size == 0 and axis is None :
1626- result = dpnp .zeros_like (result )
1627- if out is not None :
1628- out [...] = result
1629- return result
1629+ if out is not None :
1630+ pass
1631+ elif initial != 0 :
1632+ pass
1633+ elif where is not True :
1634+ pass
1635+ else :
1636+ y = dpt .sum (dpnp .get_usm_ndarray (x ), axis = axis , dtype = dtype , keepdims = keepdims )
1637+ return dpnp_array ._create_from_usm_ndarray (y )
16301638
1631- return call_origin (numpy .sum , x1 , axis = axis , dtype = dtype , out = out , keepdims = keepdims , initial = initial , where = where )
1639+ return call_origin (numpy .sum , x , axis = axis , dtype = dtype , out = out , keepdims = keepdims , initial = initial , where = where )
16321640
16331641
16341642def trapz (y1 , x1 = None , dx = 1.0 , axis = - 1 ):
0 commit comments