@@ -188,8 +188,9 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
188188 arity_arg = sizeof_dtypeN / sizeof_dtype1
189189 ndims_param = Symbol (name = 'ndims' , dtype = size_t )
190190 ndims_arg = obj .ndim
191- shape_param = Array (name = f'{ obj .name } _shape' , dtype = np .int32 ,
192- dimensions = (Dimension (name = 'd' ),), scope = 'rvalue' )
191+ shape_param = Array (name = f'{ obj .name } _shape' , scope = 'rvalue' ,
192+ dtype = np .int32 if obj .is_regular else np .uint64 ,
193+ dimensions = (Dimension (name = 'd' ),))
193194 shape_arg = ListInitializer (obj .c0 .symbolic_shape , dtype = shape_param .dtype )
194195
195196 ffp0 = FieldFromPointer (obj ._C_field_data , obj ._C_symbol )
@@ -273,8 +274,9 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
273274 arity_arg = sizeof_dtypeN / sizeof_dtype1
274275 ndims_param = Symbol (name = 'ndims' , dtype = size_t )
275276 ndims_arg = obj .ndim
276- shape_param = Array (name = f'{ obj .name } _shape' , dtype = np .uint64 ,
277- dimensions = (Dimension (name = 'd' ),), scope = 'rvalue' )
277+ shape_param = Array (name = f'{ obj .name } _shape' , scope = 'rvalue' ,
278+ dtype = np .int32 if obj .is_regular else np .uint64 ,
279+ dimensions = (Dimension (name = 'd' ),))
278280 shape_arg = ListInitializer (obj .c0 .symbolic_shape , dtype = shape_param .dtype )
279281
280282 ffp1 = FieldFromPointer (obj ._C_field_shape , obj ._C_symbol )
0 commit comments