@@ -535,6 +535,21 @@ def test_put_basic_axis():
535535 assert (expected == dpt .asnumpy (x )).all ()
536536
537537
538+ @pytest .mark .parametrize ("data_dt" , _all_dtypes )
539+ def test_put_0d_val (data_dt ):
540+ q = get_queue_or_skip ()
541+ skip_if_dtype_not_supported (data_dt , q )
542+
543+ x = dpt .arange (5 , dtype = data_dt , sycl_queue = q )
544+ ind = dpt .asarray ([0 ], dtype = np .intp , sycl_queue = q )
545+ x [ind ] = 2
546+ assert_array_equal (np .asarray (2 , dtype = data_dt ), dpt .asnumpy (x [0 ]))
547+
548+ x = dpt .asarray (5 , dtype = data_dt , sycl_queue = q )
549+ x [ind ] = 2
550+ assert_array_equal (np .asarray (2 , dtype = data_dt ), dpt .asnumpy (x ))
551+
552+
538553@pytest .mark .parametrize (
539554 "data_dt" ,
540555 _all_dtypes ,
@@ -543,8 +558,8 @@ def test_take_0d_data(data_dt):
543558 q = get_queue_or_skip ()
544559 skip_if_dtype_not_supported (data_dt , q )
545560
546- x = dpt .asarray (0 , dtype = data_dt )
547- ind = dpt .arange (5 )
561+ x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
562+ ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
548563
549564 y = dpt .take (x , ind )
550565 assert (
@@ -561,9 +576,9 @@ def test_put_0d_data(data_dt):
561576 q = get_queue_or_skip ()
562577 skip_if_dtype_not_supported (data_dt , q )
563578
564- x = dpt .asarray (0 , dtype = data_dt )
565- ind = dpt .arange (5 )
566- val = dpt .asarray (2 , dtype = data_dt )
579+ x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
580+ ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
581+ val = dpt .asarray (2 , dtype = data_dt , sycl_queue = q )
567582
568583 dpt .put (x , ind , val , axis = 0 )
569584 assert (
@@ -577,10 +592,10 @@ def test_put_0d_data(data_dt):
577592 _all_int_dtypes ,
578593)
579594def test_take_0d_ind (ind_dt ):
580- get_queue_or_skip ()
595+ q = get_queue_or_skip ()
581596
582- x = dpt .arange (5 , dtype = ind_dt )
583- ind = dpt .asarray (3 )
597+ x = dpt .arange (5 , dtype = "i4" , sycl_queue = q )
598+ ind = dpt .asarray (3 , dtype = ind_dt , sycl_queue = q )
584599
585600 y = dpt .take (x , ind )
586601 assert dpt .asnumpy (x [3 ]) == dpt .asnumpy (y )
@@ -591,11 +606,11 @@ def test_take_0d_ind(ind_dt):
591606 _all_int_dtypes ,
592607)
593608def test_put_0d_ind (ind_dt ):
594- get_queue_or_skip ()
609+ q = get_queue_or_skip ()
595610
596- x = dpt .arange (5 , dtype = ind_dt )
597- ind = dpt .asarray (3 )
598- val = dpt .asarray (5 , dtype = ind_dt )
611+ x = dpt .arange (5 , dtype = "i4" , sycl_queue = q )
612+ ind = dpt .asarray (3 , dtype = ind_dt , sycl_queue = q )
613+ val = dpt .asarray (5 , dtype = x . dtype , sycl_queue = q )
599614
600615 dpt .put (x , ind , val , axis = 0 )
601616 assert dpt .asnumpy (x [3 ]) == dpt .asnumpy (val )
@@ -750,7 +765,7 @@ def test_put_strided_1d_destination(data_dt, order):
750765
751766 x = dpt .arange (27 , dtype = data_dt , sycl_queue = q )
752767 ind = dpt .arange (4 , 9 , dtype = np .intp , sycl_queue = q )
753- val = dpt .asarray (9 , dtype = data_dt , sycl_queue = q )
768+ val = dpt .asarray (9 , dtype = x . dtype , sycl_queue = q )
754769
755770 x_np = dpt .asnumpy (x )
756771 ind_np = dpt .asnumpy (ind )
@@ -780,7 +795,7 @@ def test_put_strided_destination(data_dt, order):
780795
781796 x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
782797 ind = dpt .arange (2 , dtype = np .intp , sycl_queue = q )
783- val = dpt .asarray (9 , dtype = data_dt , sycl_queue = q )
798+ val = dpt .asarray (9 , dtype = x . dtype , sycl_queue = q )
784799
785800 x_np = dpt .asnumpy (x )
786801 ind_np = dpt .asnumpy (ind )
@@ -825,7 +840,7 @@ def test_put_strided_1d_indices(ind_dt):
825840
826841 x = dpt .arange (27 , dtype = "i4" , sycl_queue = q )
827842 ind = dpt .arange (12 , 24 , dtype = ind_dt , sycl_queue = q )
828- val = dpt .asarray (- 1 , dtype = "i4" , sycl_queue = q )
843+ val = dpt .asarray (- 1 , dtype = x . dtype , sycl_queue = q )
829844
830845 x_np = dpt .asnumpy (x )
831846 ind_np = dpt .asnumpy (ind ).astype (np .intp )
@@ -880,43 +895,53 @@ def test_put_strided_indices(ind_dt, order):
880895
881896
882897def test_take_arg_validation ():
883- get_queue_or_skip ()
898+ q = get_queue_or_skip ()
884899
885- x = dpt .arange (4 )
886- ind0 = dpt .arange (2 )
887- ind1 = dpt .arange (2.0 )
900+ x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
901+ ind0 = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
902+ ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
888903
889- with pytest .raises (ValueError ):
890- dpt .take (dpt .reshape (x , (2 , 2 )), ind0 )
891904 with pytest .raises (TypeError ):
892905 dpt .take (dict (), ind0 , axis = 0 )
893906 with pytest .raises (TypeError ):
894907 dpt .take (x , dict (), axis = 0 )
895908 with pytest .raises (TypeError ):
909+ x [[]]
910+ with pytest .raises (IndexError ):
896911 dpt .take (x , ind1 , axis = 0 )
912+ with pytest .raises (IndexError ):
913+ x [ind1 ]
897914
915+ with pytest .raises (ValueError ):
916+ dpt .take (dpt .reshape (x , (2 , 2 )), ind0 )
898917 with pytest .raises (ValueError ):
899918 dpt .take (x , ind0 , mode = 0 )
900919 with pytest .raises (ValueError ):
901920 dpt .take (dpt .reshape (x , (2 , 2 )), ind0 , axis = None )
902921
903922
904923def test_put_arg_validation ():
905- get_queue_or_skip ()
924+ q = get_queue_or_skip ()
906925
907- x = dpt .arange (4 )
908- ind0 = dpt .arange (2 )
909- ind1 = dpt .arange (2.0 )
910- val = dpt .asarray (2 )
926+ x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
927+ ind0 = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
928+ ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
929+ val = dpt .asarray (2 , x . dtype , sycl_queue = q )
911930
912931 with pytest .raises (TypeError ):
913932 dpt .put (dict (), ind0 , val , axis = 0 )
914933 with pytest .raises (TypeError ):
915934 dpt .put (x , dict (), val , axis = 0 )
916935 with pytest .raises (TypeError ):
936+ x [[]] = val
937+ with pytest .raises (IndexError ):
917938 dpt .put (x , ind1 , val , axis = 0 )
939+ with pytest .raises (IndexError ):
940+ x [ind1 ] = val
918941 with pytest .raises (TypeError ):
919942 dpt .put (x , ind0 , dict (), axis = 0 )
943+ with pytest .raises (TypeError ):
944+ x [ind0 ] = dict ()
920945
921946 with pytest .raises (ValueError ):
922947 dpt .put (x , ind0 , val , mode = 0 )
0 commit comments