@@ -158,38 +158,41 @@ def _empty_like_orderK(X, dt, usm_type=None, dev=None):
158158 return dpt .permute_dims (R , inv_perm )
159159
160160
161- def _empty_like_pair_orderK (X1 , X2 , dt , usm_type , dev ):
161+ def _empty_like_pair_orderK (X1 , X2 , dt , res_shape , usm_type , dev ):
162162 if not isinstance (X1 , dpt .usm_ndarray ):
163163 raise TypeError (f"Expected usm_ndarray, got { type (X1 )} " )
164164 if not isinstance (X2 , dpt .usm_ndarray ):
165165 raise TypeError (f"Expected usm_ndarray, got { type (X2 )} " )
166166 nd1 = X1 .ndim
167167 nd2 = X2 .ndim
168- if nd1 > nd2 :
168+ if nd1 > nd2 and X1 . shape == res_shape :
169169 return _empty_like_orderK (X1 , dt , usm_type , dev )
170- elif nd1 < nd2 :
170+ elif nd1 < nd2 and X2 . shape == res_shape :
171171 return _empty_like_orderK (X2 , dt , usm_type , dev )
172172 fl1 = X1 .flags
173173 fl2 = X2 .flags
174174 if fl1 ["C" ] or fl2 ["C" ]:
175- return dpt .empty_like (
176- X1 , dtype = dt , usm_type = usm_type , device = dev , order = "C"
175+ return dpt .empty (
176+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "C"
177177 )
178178 if fl1 ["F" ] and fl2 ["F" ]:
179- return dpt .empty_like (
180- X1 , dtype = dt , usm_type = usm_type , device = dev , order = "F"
179+ return dpt .empty (
180+ res_shape , dtype = dt , usm_type = usm_type , device = dev , order = "F"
181181 )
182182 st1 = list (X1 .strides )
183183 st2 = list (X2 .strides )
184+ max_ndim = max (nd1 , nd2 )
185+ st1 += [0 ] * (max_ndim - len (st1 ))
186+ st2 += [0 ] * (max_ndim - len (st2 ))
184187 perm = sorted (
185- range (nd1 ),
188+ range (max_ndim ),
186189 key = lambda d : (builtins .abs (st1 [d ]), builtins .abs (st2 [d ])),
187190 reverse = True ,
188191 )
189- inv_perm = sorted (range (nd1 ), key = lambda i : perm [i ])
192+ inv_perm = sorted (range (max_ndim ), key = lambda i : perm [i ])
190193 st1_sorted = [st1 [i ] for i in perm ]
191194 st2_sorted = [st2 [i ] for i in perm ]
192- sh = X1 . shape
195+ sh = res_shape
193196 sh_sorted = tuple (sh [i ] for i in perm )
194197 R = dpt .empty (sh_sorted , dtype = dt , usm_type = usm_type , device = dev , order = "C" )
195198 if max (min (st1_sorted ), min (st2_sorted )) < 0 :
0 commit comments