@@ -137,14 +137,12 @@ def permute_dims(X, axes):
137137 """
138138 if not isinstance (X , dpt .usm_ndarray ):
139139 raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
140- if not isinstance (axes , (tuple , list )):
141- axes = (axes ,)
140+ axes = normalize_axis_tuple (axes , X .ndim , "axes" )
142141 if not X .ndim == len (axes ):
143142 raise ValueError (
144143 "The length of the passed axes does not match "
145144 "to the number of usm_ndarray dimensions."
146145 )
147- axes = normalize_axis_tuple (axes , X .ndim , "axes" )
148146 newstrides = tuple (X .strides [i ] for i in axes )
149147 newshape = tuple (X .shape [i ] for i in axes )
150148 return dpt .usm_ndarray (
@@ -187,7 +185,8 @@ def expand_dims(X, axis):
187185 """
188186 if not isinstance (X , dpt .usm_ndarray ):
189187 raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
190- if not isinstance (axis , (tuple , list )):
188+
189+ if type (axis ) not in (tuple , list ):
191190 axis = (axis ,)
192191
193192 out_ndim = len (axis ) + X .ndim
@@ -224,8 +223,6 @@ def squeeze(X, axis=None):
224223 raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
225224 X_shape = X .shape
226225 if axis is not None :
227- if not isinstance (axis , (tuple , list )):
228- axis = (axis ,)
229226 axis = normalize_axis_tuple (axis , X .ndim if X .ndim != 0 else X .ndim + 1 )
230227 new_shape = []
231228 for i , x in enumerate (X_shape ):
@@ -819,12 +816,6 @@ def moveaxis(X, source, destination):
819816 if not isinstance (X , dpt .usm_ndarray ):
820817 raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
821818
822- if not isinstance (source , (tuple , list )):
823- source = (source ,)
824-
825- if not isinstance (destination , (tuple , list )):
826- destination = (destination ,)
827-
828819 source = normalize_axis_tuple (source , X .ndim , "source" )
829820 destination = normalize_axis_tuple (destination , X .ndim , "destination" )
830821
0 commit comments