@@ -992,15 +992,36 @@ def test_full_dtype_inference():
992992def test_full_fill_array ():
993993 q = get_queue_or_skip ()
994994
995- Xnp = np .array ([1 , 2 , 3 ], dtype = np . int32 )
995+ Xnp = np .array ([1 , 2 , 3 ], dtype = "i4" )
996996 X = dpt .asarray (Xnp , sycl_queue = q )
997997
998998 shape = (3 , 3 )
999999 Y = dpt .full (shape , X )
10001000 Ynp = np .full (shape , Xnp )
10011001
1002+ assert Y .dtype == Ynp .dtype
1003+ assert Y .usm_type == "device"
10021004 assert np .array_equal (dpt .asnumpy (Y ), Ynp )
1003- assert Ynp .dtype == Y .dtype
1005+
1006+
1007+ def test_full_compute_follows_data ():
1008+ q1 = get_queue_or_skip ()
1009+ q2 = get_queue_or_skip ()
1010+
1011+ X = dpt .arange (10 , dtype = "i4" , sycl_queue = q1 , usm_type = "shared" )
1012+ Y = dpt .full (10 , X [3 ])
1013+
1014+ assert Y .dtype == X .dtype
1015+ assert Y .usm_type == X .usm_type
1016+ assert dpctl .utils .get_execution_queue ((Y .sycl_queue , X .sycl_queue ))
1017+ assert np .array_equal (dpt .asnumpy (Y ), np .full (10 , 3 , dtype = "i4" ))
1018+
1019+ Y = dpt .full (10 , X [3 ], dtype = "f4" , sycl_queue = q2 , usm_type = "host" )
1020+
1021+ assert Y .dtype == dpt .dtype ("f4" )
1022+ assert Y .usm_type == "host"
1023+ assert dpctl .utils .get_execution_queue ((Y .sycl_queue , q2 ))
1024+ assert np .array_equal (dpt .asnumpy (Y ), np .full (10 , 3 , dtype = "f4" ))
10041025
10051026
10061027@pytest .mark .parametrize (
0 commit comments