3939
4040
4141import numpy
42+ from numpy .core .numeric import normalize_axis_tuple
4243
4344import dpnp
4445from dpnp .dpnp_algo import *
@@ -66,9 +67,9 @@ def dot(a, b, out=None):
6667
6768 Parameters
6869 ----------
69- a : {dpnp_array , usm_ndarray, scalar}
70+ a : {dpnp.ndarray , usm_ndarray, scalar}
7071 First input array. Both inputs `a` and `b` can not be scalars at the same time.
71- b : {dpnp_array , usm_ndarray, scalar}
72+ b : {dpnp.ndarray , usm_ndarray, scalar}
7273 Second input array. Both inputs `a` and `b` can not be scalars at the same time.
7374 out : {dpnp.ndarray, usm_ndarray}, optional
7475 Alternative output array in which to place the result. It must have
@@ -404,42 +405,152 @@ def outer(x1, x2, out=None):
404405 return call_origin (numpy .outer , x1 , x2 , out = out )
405406
406407
def tensordot(a, b, axes=2):
    r"""
    Compute tensor dot product along specified axes.

    For full documentation refer to :obj:`numpy.tensordot`.

    Parameters
    ----------
    a : {dpnp.ndarray, usm_ndarray, scalar}
        First input array. Both inputs `a` and `b` can not be scalars
        at the same time.
    b : {dpnp.ndarray, usm_ndarray, scalar}
        Second input array. Both inputs `a` and `b` can not be scalars
        at the same time.
    axes : int or (2,) array_like
        * integer_like
          If an int `N`, sum over the last `N` axes of `a` and the first
          `N` axes of `b` in order. The sizes of the corresponding axes
          must match. `N` must be non-negative.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying
          to `a`, second to `b`. Both elements array_like must be of the
          same length.

    Returns
    -------
    out : dpnp.ndarray
        Returns the tensordot product of `a` and `b`.

    Raises
    ------
    TypeError
        If `axes` is neither iterable nor an integer.
    ValueError
        If `axes` is a negative integer, does not consist of exactly two
        elements, the two axis sequences differ in length, an axis is
        repeated, or the sizes of paired axes differ.
    numpy.AxisError
        If a requested axis is out of bounds for the corresponding input.

    See Also
    --------
    :obj:`dpnp.dot` : Returns the dot product.
    :obj:`dpnp.einsum` : Evaluates the Einstein summation convention
                         on the operands.

    Notes
    -----
    Three common use cases are:
        * ``axes = 0`` : tensor product :math:`a \otimes b`
        * ``axes = 1`` : tensor dot product :math:`a \cdot b`
        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`

    When `axes` is an integer `N`, the axes are paired in order: the -Nth
    axis of `a` with the 0th axis of `b` first, and the -1th axis of `a`
    with the (N-1)th axis of `b` last.

    When there is more than one axis to sum over - and they are not the last
    (first) axes of `a` (`b`) - the argument `axes` should consist of
    two sequences of the same length, with the first axis to sum over given
    first in both sequences, the second axis second, and so forth.

    The shape of the result consists of the non-contracted axes of the
    first tensor, followed by the non-contracted axes of the second.

    Examples
    --------
    >>> import dpnp as np
    >>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> b = np.array([1, 2, 3])
    >>> np.tensordot(a, b, 1)
    array([14, 32, 50])

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
    >>> c.shape
    (5, 2)
    >>> c
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    A slower but equivalent way of computing the same...

    >>> d = np.zeros((5,2))
    >>> for i in range(5):
    ...   for j in range(2):
    ...     for k in range(3):
    ...       for n in range(4):
    ...         d[i,j] += a[k,n,i] * b[n,k,j]
    >>> c == d
    array([[ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True]])

    """

    dpnp.check_supported_arrays_type(a, b, scalar_type=True)

    # Promote a scalar operand to an array allocated like the other operand.
    if dpnp.isscalar(a):
        a = dpnp.array(a, sycl_queue=b.sycl_queue, usm_type=b.usm_type)
    elif dpnp.isscalar(b):
        b = dpnp.array(b, sycl_queue=a.sycl_queue, usm_type=a.usm_type)

    try:
        iter(axes)
    except TypeError as exc:
        # Non-iterable `axes`: must be an integer count of axes to contract.
        # NumPy integer scalars are accepted alongside the built-in int.
        if not isinstance(axes, (int, numpy.integer)):
            raise TypeError("Axes must be an integer.") from exc
        axes = int(axes)
        if axes < 0:
            # A negative count would yield empty ranges below and silently
            # degrade to a tensor product, so reject it explicitly.
            raise ValueError("Axes must be a non-negative integer.") from exc
        axes_a = tuple(range(-axes, 0))
        axes_b = tuple(range(0, axes))
    else:
        if len(axes) != 2:
            raise ValueError("Axes must consist of two sequences.")

        axes_a, axes_b = axes
        axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
        axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b

        if len(axes_a) != len(axes_b):
            raise ValueError("Axes length mismatch.")

    a_shape = a.shape
    b_shape = b.shape
    a_ndim = a.ndim
    b_ndim = b.ndim

    # Make the axes non-negative. This runs before the shape comparison so
    # that an out-of-range axis is reported by the normalizer instead of
    # surfacing as a bare IndexError from tuple indexing.
    axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis")
    axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis")

    for axis_a, axis_b in zip(axes_a, axes_b):
        if a_shape[axis_a] != b_shape[axis_b]:
            raise ValueError(
                "shape of input arrays is not similar at requested axes."
            )

    # Move the axes to sum over, to the end of "a"
    free_a = tuple(k for k in range(a_ndim) if k not in axes_a)
    newaxes_a = free_a + axes_a
    olda = [a_shape[axis] for axis in free_a]
    newshape_a = (
        int(numpy.prod(olda)),
        int(numpy.prod([a_shape[ax] for ax in axes_a])),
    )

    # Move the axes to sum over, to the front of "b"
    free_b = tuple(k for k in range(b_ndim) if k not in axes_b)
    newaxes_b = axes_b + free_b
    oldb = [b_shape[axis] for axis in free_b]
    newshape_b = (
        int(numpy.prod([b_shape[ax] for ax in axes_b])),
        int(numpy.prod(oldb)),
    )

    # Collapse both operands to 2-D so the contraction is a single matmul,
    # then restore the non-contracted axes of "a" followed by those of "b".
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dpnp.matmul(at, bt)

    return res.reshape(olda + oldb)
444555
445556def vdot (a , b ):
@@ -450,11 +561,11 @@ def vdot(a, b):
450561
451562 Parameters
452563 ----------
453- a : {dpnp_array , usm_ndarray, scalar}
564+ a : {dpnp.ndarray , usm_ndarray, scalar}
454565 First input array. Both inputs `a` and `b` can not be
455566 scalars at the same time. If `a` is complex, the complex
456567 conjugate is taken before the calculation of the dot product.
457- b : {dpnp_array , usm_ndarray, scalar}
568+ b : {dpnp.ndarray , usm_ndarray, scalar}
458569 Second input array. Both inputs `a` and `b` can not be
459570 scalars at the same time.
460571
0 commit comments