@@ -1190,7 +1190,7 @@ def test_np_linspace(config, dtype, endpoint, retstep):
     np_ret = onp.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
     if retstep:
         assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5)
-        same(mx_ret[1], np_ret[1])
+        assert same(mx_ret[1], np_ret[1])
     else:
         assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5)
 
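A note on the pattern this hunk (and the ones below it) fixes: `same` from `mxnet.test_utils` returns a boolean rather than raising, so a bare `same(...)` call discards its result and the comparison never actually fails the test. A minimal sketch of the failure mode, assuming `same` is essentially an `array_equal` wrapper as in `mxnet.test_utils`:

```python
import numpy as onp

def same(a, b):
    # Stand-in for mxnet.test_utils.same, essentially an array_equal check.
    return onp.array_equal(a, b)

a, b = onp.array([1.0]), onp.array([2.0])
same(a, b)         # returns False, but the result is discarded: no failure
assert same(a, a)  # the asserted form actually gates the test on mismatch
```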
@@ -3735,13 +3735,13 @@ def forward(self, *arys):
         np_out = funcs["numpy"][n](*tensors_np)
         for i in range(len(tensors)):
             assert mx_out[i].shape == np_out[i].shape
-            same(mx_out[i].asnumpy(), np_out[i])
+            assert same(mx_out[i].asnumpy(), np_out[i])
 
         mx_out = funcs["mxnet"][n](*tensors)
         np_out = funcs["numpy"][n](*tensors_np)
         for i in range(len(tensors)):
             assert mx_out[i].shape == np_out[i].shape
-            same(mx_out[i].asnumpy(), np_out[i])
+            assert same(mx_out[i].asnumpy(), np_out[i])
 
 
 @use_np
@@ -5760,7 +5760,7 @@ def test_np_indices():
         for shape in shapes:
             np_out = onp.indices(dimensions=shape, dtype=dtype)
             mx_out = np.indices(dimensions=shape, dtype=dtype)
-            same(mx_out.asnumpy(), np_out)
+            assert same(mx_out.asnumpy(), np_out)
             assert mx_out.shape == np_out.shape
 
 @use_np
@@ -5782,7 +5782,7 @@ def forward(self, x):
         if hybridize:
             net.hybridize()
         mx_out = net(x)
-        same(mx_out.asnumpy(), np_out)
+        assert same(mx_out.asnumpy(), np_out)
         assert mx_out.shape == np_out.shape
 
 
@@ -8470,14 +8470,18 @@ def forward(self, a, indices):
             return np.take(a, indices, axis=self._axis, mode=self._mode)
 
     def grad_helper(grad_in, axis, idx, mode):
-        k = grad_in.shape[axis]
+        k = 1 if axis == None else grad_in.shape[axis]
         if mode == 'clip':
             idx = 0 if idx < 0 else idx
             idx = k - 1 if idx >= k else idx
         else:
             idx = idx % k
+
         if axis == None:
-            grad_in[idx] += 1.0
+            if grad_in.shape == ():
+                grad_in += 1.0
+            else:
+                grad_in[idx] += 1.0
         elif axis == 0:
             if axis == len(grad_in.shape) - 1:
                 grad_in[idx] += 1.0
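Context for the new `axis == None` branches: `np.take` with `axis=None` operates on the flattened input, and when the input is 0-d there is nothing to index, so the unit gradient must be accumulated into the scalar array itself. A plain-NumPy illustration of the edge case (not the test's own code):

```python
import numpy as onp

a = onp.array(3.0)                  # 0-d input
out = onp.take(a, 0, mode='clip')   # axis=None: indexes the flattened view
grad_in = onp.zeros(a.shape)        # shape == (), a 0-d gradient buffer

# grad_in[0] += 1.0 would raise IndexError ("too many indices") on a 0-d
# array, hence the grad_in.shape == () special case in grad_helper.
grad_in += 1.0
print(out, grad_in)                 # 3.0 1.0
```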
@@ -8506,7 +8510,8 @@ def grad_helper(grad_in, axis, idx, mode):
     def check_output_n_grad(data_shape, idx_shape, axis, mode):
         data_real = onp.random.normal(size=data_shape).astype('float32')
         idx_real = onp.random.randint(low=-100, high=100, size=idx_shape)
-        same(np.take(np.array(data_real), np.array(idx_real), axis=axis, mode=mode).asnumpy(),
+
+        assert same(np.take(np.array(data_real), np.array(idx_real), axis=axis, mode=mode).asnumpy(),
             onp.take(data_real, idx_real, axis=axis, mode=mode))
 
         grad_in = onp.zeros(data_shape, dtype='float32')
@@ -8518,15 +8523,15 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
         x.attach_grad()
         with mx.autograd.record():
             mx_out = test_take(x, np.array(idx_real))
-        same(mx_out.asnumpy(), onp.take(data_real, idx_real, axis=axis, mode=mode))
+        assert same(mx_out.asnumpy(), onp.take(data_real, idx_real, axis=axis, mode=mode))
 
         if axis and axis < 0:
             axis += len(data_shape)
-        try:
+
+        if idx_real.size != 0:
             for i in onp.nditer(idx_real):
                 grad_helper(grad_in, axis, i, mode)
-        except:
-            pass
+
 
         mx_out.backward()
         same(x.grad.asnumpy(), grad_in)
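Replacing the bare `try/except: pass` with an explicit size check makes the intent visible: the only failure the old handler was masking is that `numpy.nditer` refuses zero-sized operands by default, so an empty index array silently skipped gradient accumulation via the exception path. A small sketch of that behavior:

```python
import numpy as onp

idx = onp.zeros((0,), dtype=onp.int64)   # empty index array

try:
    for i in onp.nditer(idx):
        pass
except ValueError as e:
    print(e)   # nditer disallows zero-sized operands by default

if idx.size != 0:        # the explicit guard used in the fixed test
    for i in onp.nditer(idx):
        pass
```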
@@ -10195,7 +10200,7 @@ def forward(self, cond, x, y):
     ]
     flags = [True, False]
     for ctype, dtype, shape_pair, hybridize in itertools.product(dtypes, dtypes, shape_configs, flags):
-        cond = np.random.uniform(low=0, high=100, size=shape_pair[0], dtype='float64').astype(ctype)
+        cond = np.round(np.random.uniform(low=0, high=2, size=shape_pair[0], dtype='float64')).astype(ctype)
         x = np.random.uniform(low=0, high=100, size=shape_pair[1], dtype='float64').astype(dtype)
         y = np.random.uniform(low=0, high=100, size=shape_pair[2], dtype='float64').astype(dtype)
         cond.attach_grad()
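Why the `cond` generator changes: a draw from `uniform(low=0, high=100)` is nonzero with probability one, so the old condition tensor made `where` always select `x` and left the `y` branch (and the zero side of the gradient) untested. Rounding a `uniform(0, 2)` draw produces a mix of zeros and nonzeros. A plain-NumPy sketch of the difference:

```python
import numpy as onp

rng = onp.random.default_rng(0)
old_cond = rng.uniform(0, 100, size=1000)            # almost surely all nonzero
new_cond = onp.round(rng.uniform(0, 2, size=1000))   # mix of 0s, 1s and 2s

print((old_cond == 0).any())   # False: the y branch is never exercised
print((new_cond == 0).any())   # True: both branches get exercised
```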
@@ -10206,37 +10211,50 @@ def forward(self, cond, x, y):
             test_mod.hybridize()
         with mx.autograd.record():
             ret = test_mod(cond, x, y)
-        same(ret.asnumpy(), onp.where(cond.asnumpy(), x.asnumpy(), y.asnumpy()))
+
+        assert same(ret.asnumpy(), onp.where(cond.asnumpy(), x.asnumpy(), y.asnumpy()))
         if dtype in [np.float16, np.float32, np.float64]:
             ret.backward()
-            same(cond.grad.asnumpy(), onp.zeros(shape_pair[0], dtype=ctype))
-            same(x.grad.asnumpy(), collapse_sum_like(onp.broadcast_to(cond.asnumpy(), ret.shape), shape_pair[1]))
+            assert same(cond.grad.asnumpy(), onp.zeros(shape_pair[0], dtype=ctype))
+
+            xgrad = x.grad.asnumpy()
+            npgrad = collapse_sum_like((onp.broadcast_to(cond.asnumpy(), ret.shape) != 0).astype(dtype), shape_pair[1])
+            npgrad = npgrad.astype(xgrad.dtype)
+            assert same(xgrad, npgrad)
 
         # check imperative again
         ret = np.where(cond, x, y)
-        same(ret.asnumpy(), onp.where(cond.asnumpy(), x.asnumpy(), y.asnumpy()))
+        assert same(ret.asnumpy(), onp.where(cond.asnumpy(), x.asnumpy(), y.asnumpy()))
 
         # check scalar case
         if dtype in [np.float16, np.float32, np.float64]:
             # lscalar
             with mx.autograd.record():
                 ret_lscalar = np.where(cond, 1, x)
-            same(ret.asnumpy(), onp.where(cond.asnumpy(), 1, x.asnumpy()))
+            assert same(ret_lscalar.asnumpy(), onp.where(cond.asnumpy(), 1, x.asnumpy()))
             ret_lscalar.backward()
-            same(x.grad.asnumpy(), 1 - collapse_sum_like(onp.broadcast_to(cond.asnumpy(), ret.shape), shape_pair[1]))
+
+            xgrad = x.grad.asnumpy()
+            npgrad = collapse_sum_like((onp.broadcast_to(cond.asnumpy(), ret_lscalar.shape) == 0).astype(dtype), shape_pair[1])
+            npgrad = npgrad.astype(xgrad.dtype)
+            assert same(xgrad, npgrad)
             # rscalar
             with mx.autograd.record():
                 ret_rscalar = np.where(cond, x, 1)
-            same(ret.asnumpy(), onp.where(cond.asnumpy(), x.asnumpy(), 1))
+            assert same(ret_rscalar.asnumpy(), onp.where(cond.asnumpy(), x.asnumpy(), 1))
             ret_rscalar.backward()
-            same(x.grad.asnumpy(), collapse_sum_like(onp.broadcast_to(cond.asnumpy(), ret.shape), shape_pair[1]))
+
+            xgrad = x.grad.asnumpy()
+            npgrad = collapse_sum_like((onp.broadcast_to(cond.asnumpy(), ret_rscalar.shape) != 0).astype(dtype), shape_pair[1])
+            npgrad = npgrad.astype(xgrad.dtype)
+            assert same(xgrad, npgrad)
 
     # check both scalar case
     x = onp.random.randint(0, 100)
    y = onp.random.randint(0, 100)
     mx_out = np.where(cond, x, y)
     np_out = onp.where(cond, x, y)
-    same(mx_out, np_out)
+    assert same(mx_out, np_out)
 
 
 @use_np
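The rewritten gradient checks encode the identity the old ones got wrong: for `ret = where(cond, x, y)`, the gradient with respect to `x` is the indicator `(cond != 0)` (or `(cond == 0)` when `x` sits in the else position, as in the lscalar case), summed over broadcast axes back down to `x`'s shape. The old expectation summed the raw `cond` values, which only coincides with the indicator when `cond` is 0/1-valued. A self-contained sketch, with a hypothetical stand-in for the test suite's `collapse_sum_like` helper:

```python
import numpy as onp

def collapse_sum_like(a, shape):
    # Hypothetical reimplementation of the test suite's helper: sum `a`
    # over the axes that broadcasting expanded, down to `shape`.
    a = a.sum(axis=tuple(range(a.ndim - len(shape))))
    axes = tuple(i for i, s in enumerate(shape) if s == 1 and a.shape[i] != 1)
    return a.sum(axis=axes, keepdims=True)

cond = onp.array([[0., 1.], [2., 0.]])
x = onp.ones((1, 2))                      # broadcasts against cond
# d(where(cond, x, y))/dx is 1 exactly where cond != 0:
expected = collapse_sum_like((cond != 0).astype(x.dtype), x.shape)
print(expected)                           # [[1. 1.]]
```

With the rounded `cond` from the earlier hunk no longer guaranteed to be 0/1-valued, the explicit `!= 0` indicator (plus the `astype` to match the dtype of MXNet's computed gradient) is what makes these assertions exact.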