@@ -373,9 +373,25 @@ def test_add_inplace_dtype_matrix(op1_dtype, op2_dtype):
373373 else :
374374 with pytest .raises (ValueError ):
375375 ar1 += ar2
376+
377+ ar1 = dpt .ones (sz , dtype = op1_dtype )
378+ ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
379+ if _can_cast (ar2 .dtype , ar1 .dtype , _fp16 , _fp64 ):
380+ dpt .add (ar1 , ar2 , out = ar1 )
381+ assert (
382+ dpt .asnumpy (ar1 ) == np .full (ar1 .shape , 2 , dtype = ar1 .dtype )
383+ ).all ()
384+
385+ ar3 = dpt .ones (sz , dtype = op1_dtype )[::- 1 ]
386+ ar4 = dpt .ones (2 * sz , dtype = op2_dtype )[::2 ]
387+ dpt .add (ar3 , ar4 , out = ar3 )
388+ assert (
389+ dpt .asnumpy (ar3 ) == np .full (ar3 .shape , 2 , dtype = ar3 .dtype )
390+ ).all ()
391+ else :
392+ with pytest .raises (ValueError ):
376393 dpt .add (ar1 , ar2 , out = ar1 )
377394
378- # out is second arg
379395 ar1 = dpt .ones (sz , dtype = op1_dtype )
380396 ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
381397 if _can_cast (ar1 .dtype , ar2 .dtype , _fp16 , _fp64 ):
@@ -401,7 +417,7 @@ def test_add_inplace_broadcasting():
401417 m = dpt .ones ((100 , 5 ), dtype = "i4" )
402418 v = dpt .arange (5 , dtype = "i4" )
403419
404- m += v
420+ dpt . add ( m , v , out = m )
405421 assert (dpt .asnumpy (m ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
406422
407423 # check case where second arg is out
@@ -411,6 +427,26 @@ def test_add_inplace_broadcasting():
411427 ).all ()
412428
413429
430+ def test_add_inplace_operator_broadcasting ():
431+ get_queue_or_skip ()
432+
433+ m = dpt .ones ((100 , 5 ), dtype = "i4" )
434+ v = dpt .arange (5 , dtype = "i4" )
435+
436+ m += v
437+ assert (dpt .asnumpy (m ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
438+
439+
440+ def test_add_inplace_operator_mutual_broadcast ():
441+ get_queue_or_skip ()
442+
443+ x1 = dpt .ones ((1 , 10 ), dtype = "i4" )
444+ x2 = dpt .ones ((10 , 1 ), dtype = "i4" )
445+
446+ with pytest .raises (ValueError ):
447+ dpt .add ._inplace_op (x1 , x2 )
448+
449+
414450def test_add_inplace_errors ():
415451 get_queue_or_skip ()
416452 try :
@@ -425,27 +461,45 @@ def test_add_inplace_errors():
425461 ar1 = dpt .ones (2 , dtype = "float32" , sycl_queue = gpu_queue )
426462 ar2 = dpt .ones_like (ar1 , sycl_queue = cpu_queue )
427463 with pytest .raises (ExecutionPlacementError ):
428- ar1 += ar2
464+ dpt . add ( ar1 , ar2 , out = ar1 )
429465
430466 ar1 = dpt .ones (2 , dtype = "float32" )
431467 ar2 = dpt .ones (3 , dtype = "float32" )
432468 with pytest .raises (ValueError ):
433- ar1 += ar2
469+ dpt . add ( ar1 , ar2 , out = ar1 )
434470
435471 ar1 = np .ones (2 , dtype = "float32" )
436472 ar2 = dpt .ones (2 , dtype = "float32" )
437473 with pytest .raises (TypeError ):
438- ar1 += ar2
474+ dpt . add ( ar1 , ar2 , out = ar1 )
439475
440476 ar1 = dpt .ones (2 , dtype = "float32" )
441477 ar2 = dict ()
442478 with pytest .raises (ValueError ):
443- ar1 += ar2
479+ dpt . add ( ar1 , ar2 , out = ar1 )
444480
445481 ar1 = dpt .ones ((2 , 1 ), dtype = "float32" )
446482 ar2 = dpt .ones ((1 , 2 ), dtype = "float32" )
447483 with pytest .raises (ValueError ):
448- ar1 += ar2
484+ dpt .add (ar1 , ar2 , out = ar1 )
485+
486+
487+ def test_add_inplace_operator_errors ():
488+ q1 = get_queue_or_skip ()
489+ q2 = get_queue_or_skip ()
490+
491+ x = dpt .ones (10 , dtype = "i4" , sycl_queue = q1 )
492+ with pytest .raises (TypeError ):
493+ dpt .add ._inplace_op (dict (), x )
494+
495+ x .flags ["W" ] = False
496+ with pytest .raises (ValueError ):
497+ dpt .add ._inplace_op (x , 2 )
498+
499+ x_q1 = dpt .ones (10 , dtype = "i4" , sycl_queue = q1 )
500+ x_q2 = dpt .ones (10 , dtype = "i4" , sycl_queue = q2 )
501+ with pytest .raises (ExecutionPlacementError ):
502+ dpt .add ._inplace_op (x_q1 , x_q2 )
449503
450504
451505def test_add_inplace_same_tensors ():
0 commit comments