@@ -65,7 +65,13 @@ def _array_info_dispatch(obj):
6565 return _empty_tuple , int , _host_set
6666 if isinstance (obj , complex ):
6767 return _empty_tuple , complex , _host_set
68- if isinstance (obj , (list , tuple , range )):
68+ if isinstance (
69+ obj ,
70+ (
71+ list ,
72+ tuple ,
73+ ),
74+ ):
6975 return _array_info_sequence (obj )
7076 if _is_object_with_buffer_protocol (obj ):
7177 np_obj = np .array (obj )
@@ -329,7 +335,11 @@ def _usm_types_walker(o, usm_types_list):
329335 usm_ar = _usm_ndarray_from_suai (o )
330336 usm_types_list .append (usm_ar .usm_type )
331337 return
332- if isinstance (o , (list , tuple )):
338+ if _is_object_with_buffer_protocol (o ):
339+ return
340+ if isinstance (o , (int , bool , float , complex )):
341+ return
342+ if isinstance (o , (list , tuple , range )):
333343 for el in o :
334344 _usm_types_walker (el , usm_types_list )
335345 return
@@ -361,11 +371,37 @@ def _device_copy_walker(seq_o, res, events):
361371
362372def _copy_through_host_walker (seq_o , usm_res ):
363373 if isinstance (seq_o , dpt .usm_ndarray ):
364- usm_res [...] = dpt .asnumpy (seq_o ).copy ()
365- return
374+ if (
375+ dpctl .utils .get_execution_queue (
376+ (
377+ usm_res .sycl_queue ,
378+ seq_o .sycl_queue ,
379+ )
380+ )
381+ is None
382+ ):
383+ usm_res [...] = dpt .asnumpy (seq_o ).copy ()
384+ return
385+ else :
386+ usm_res [...] = seq_o
366387 if hasattr (seq_o , "__sycl_usm_array_interface__" ):
367388 usm_ar = _usm_ndarray_from_suai (seq_o )
368- usm_res [...] = dpt .asnumpy (usm_ar ).copy ()
389+ if (
390+ dpctl .utils .get_execution_queue (
391+ (
392+ usm_res .sycl_queue ,
393+ usm_ar .sycl_queue ,
394+ )
395+ )
396+ is None
397+ ):
398+ usm_res [...] = dpt .asnumpy (usm_ar ).copy ()
399+ else :
400+ usm_res [...] = usm_ar
401+ return
402+ if _is_object_with_buffer_protocol (seq_o ):
403+ np_ar = np .asarray (seq_o )
404+ usm_res [...] = np_ar
369405 return
370406 if isinstance (seq_o , (list , tuple )):
371407 for i , el in enumerate (seq_o ):
@@ -378,10 +414,10 @@ def _asarray_from_seq(
378414 seq_obj ,
379415 seq_shape ,
380416 seq_dt ,
381- seq_dev ,
417+ alloc_q ,
418+ exec_q ,
382419 dtype = None ,
383420 usm_type = None ,
384- sycl_queue = None ,
385421 order = "C" ,
386422):
387423 "`obj` is a sequence"
@@ -390,24 +426,13 @@ def _asarray_from_seq(
390426 _usm_types_walker (seq_obj , usm_types_in_seq )
391427 usm_type = dpctl .utils .get_coerced_usm_type (usm_types_in_seq )
392428 dpctl .utils .validate_usm_type (usm_type )
393- if sycl_queue is None :
394- exec_q = seq_dev
395- alloc_q = seq_dev
396- else :
397- exec_q = dpctl .utils .get_execution_queue (
398- (
399- sycl_queue ,
400- seq_dev ,
401- )
402- )
403- alloc_q = sycl_queue
404429 if dtype is None :
405430 dtype = _map_to_device_dtype (seq_dt , alloc_q )
406431 else :
407432 _mapped_dt = _map_to_device_dtype (dtype , alloc_q )
408433 if _mapped_dt != dtype :
409434 raise ValueError (
410- f"Device { sycl_queue .sycl_device } "
435+ f"Device { alloc_q .sycl_device } "
411436 f"does not support { dtype } natively."
412437 )
413438 dtype = _mapped_dt
@@ -437,6 +462,39 @@ def _asarray_from_seq(
437462 return res
438463
439464
465+ def _asarray_from_seq_single_device (
466+ obj ,
467+ seq_shape ,
468+ seq_dt ,
469+ seq_dev ,
470+ dtype = None ,
471+ usm_type = None ,
472+ sycl_queue = None ,
473+ order = "C" ,
474+ ):
475+ if sycl_queue is None :
476+ exec_q = seq_dev
477+ alloc_q = seq_dev
478+ else :
479+ exec_q = dpctl .utils .get_execution_queue (
480+ (
481+ sycl_queue ,
482+ seq_dev ,
483+ )
484+ )
485+ alloc_q = sycl_queue
486+ return _asarray_from_seq (
487+ obj ,
488+ seq_shape ,
489+ seq_dt ,
490+ alloc_q ,
491+ exec_q ,
492+ dtype = dtype ,
493+ usm_type = usm_type ,
494+ order = order ,
495+ )
496+
497+
440498def asarray (
441499 obj ,
442500 dtype = None ,
@@ -576,16 +634,42 @@ def asarray(
576634 order = order ,
577635 )
578636 elif len (devs ) == 1 :
579- return _asarray_from_seq (
637+ seq_dev = list (devs )[0 ]
638+ return _asarray_from_seq_single_device (
580639 obj ,
581640 seq_shape ,
582641 seq_dt ,
583- list ( devs )[ 0 ] ,
642+ seq_dev ,
584643 dtype = dtype ,
585644 usm_type = usm_type ,
586645 sycl_queue = sycl_queue ,
587646 order = order ,
588647 )
648+ elif len (devs ) > 1 :
649+ devs = [dev for dev in devs if dev is not None ]
650+ if sycl_queue is None :
651+ if len (devs ) == 1 :
652+ alloc_q = devs [0 ]
653+ else :
654+ raise dpctl .utils .ExecutionPlacementError (
655+ "Please specify `device` or `sycl_queue` keyword "
656+ "argument to determine where to allocate the "
657+ "resulting array."
658+ )
659+ else :
660+ alloc_q = sycl_queue
661+ return _asarray_from_seq (
662+ obj ,
663+ seq_shape ,
664+ seq_dt ,
665+ alloc_q ,
666+ # force copying via host
667+ None ,
668+ dtype = dtype ,
669+ usm_type = usm_type ,
670+ order = order ,
671+ )
672+
589673 raise NotImplementedError (
590674 "Converting Python sequences is not implemented"
591675 )
0 commit comments