@@ -1427,6 +1427,57 @@ def test_tril_order_k(order, k):
14271427 assert np .array_equal (Ynp , dpt .asnumpy (Y ))
14281428
14291429
1430+ def test_meshgrid ():
1431+ try :
1432+ q = dpctl .SyclQueue ()
1433+ except dpctl .SyclQueueCreationError :
1434+ pytest .skip ("Queue could not be created" )
1435+ X = dpt .arange (5 , sycl_queue = q )
1436+ Y = dpt .arange (3 , sycl_queue = q )
1437+ Z = dpt .meshgrid (X , Y )
1438+ Znp = np .meshgrid (dpt .asnumpy (X ), dpt .asnumpy (Y ))
1439+ n = len (Z )
1440+ assert n == len (Znp )
1441+ for i in range (n ):
1442+ assert np .array_equal (dpt .asnumpy (Z [i ]), Znp [i ])
1443+ # dimension > 1 must raise ValueError
1444+ with pytest .raises (ValueError ):
1445+ dpt .meshgrid (dpt .usm_ndarray ((4 , 4 )))
1446+ # unknown indexing kwarg must raise ValueError
1447+ with pytest .raises (ValueError ):
1448+ dpt .meshgrid (X , indexing = "ji" )
1449+ # input arrays with different data types must raise ValueError
1450+ with pytest .raises (ValueError ):
1451+ dpt .meshgrid (X , dpt .asarray (Y , dtype = "b1" ))
1452+
1453+
1454+ def test_meshgrid2 ():
1455+ try :
1456+ q1 = dpctl .SyclQueue ()
1457+ q2 = dpctl .SyclQueue ()
1458+ q3 = dpctl .SyclQueue ()
1459+ except dpctl .SyclQueueCreationError :
1460+ pytest .skip ("Queue could not be created" )
1461+ x1 = dpt .arange (0 , 2 , dtype = "int16" , sycl_queue = q1 )
1462+ x2 = dpt .arange (3 , 6 , dtype = "int16" , sycl_queue = q2 )
1463+ x3 = dpt .arange (6 , 10 , dtype = "int16" , sycl_queue = q3 )
1464+ y1 , y2 , y3 = dpt .meshgrid (x1 , x2 , x3 , indexing = "xy" )
1465+ z1 , z2 , z3 = dpt .meshgrid (x1 , x2 , x3 , indexing = "ij" )
1466+ assert all (
1467+ x .sycl_queue == y .sycl_queue for x , y in zip ((x1 , x2 , x3 ), (y1 , y2 , y3 ))
1468+ )
1469+ assert all (
1470+ x .sycl_queue == z .sycl_queue for x , z in zip ((x1 , x2 , x3 ), (z1 , z2 , z3 ))
1471+ )
1472+ assert y1 .shape == y2 .shape and y2 .shape == y3 .shape
1473+ assert z1 .shape == z2 .shape and z2 .shape == z3 .shape
1474+ assert y1 .shape == (len (x2 ), len (x1 ), len (x3 ))
1475+ assert z1 .shape == (len (x1 ), len (x2 ), len (x3 ))
1476+ # FIXME: uncomment out once gh-921 is merged
1477+ # assert all(z.flags["C"] for z in (z1, z2, z3))
1478+ # assert all(y.flags["C"] for y in (y1, y2, y3))
1479+
1480+
14301481def test_common_arg_validation ():
14311482 order = "I"
14321483 # invalid order must raise ValueError
@@ -1463,3 +1514,5 @@ def test_common_arg_validation():
14631514 dpt .tril (X )
14641515 with pytest .raises (TypeError ):
14651516 dpt .triu (X )
1517+ with pytest .raises (TypeError ):
1518+ dpt .meshgrid (X )
0 commit comments