Skip to content

Commit e65756d

Browse files
committed
compiler: Polish codegen for Arrays
1 parent baa4d1e commit e65756d

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

devito/passes/iet/definitions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,9 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
188188
arity_arg = sizeof_dtypeN / sizeof_dtype1
189189
ndims_param = Symbol(name='ndims', dtype=size_t)
190190
ndims_arg = obj.ndim
191-
shape_param = Array(name=f'{obj.name}_shape', dtype=np.int32,
192-
dimensions=(Dimension(name='d'),), scope='rvalue')
191+
shape_param = Array(name=f'{obj.name}_shape', scope='rvalue',
192+
dtype=np.int32 if obj.is_regular else np.uint64,
193+
dimensions=(Dimension(name='d'),))
193194
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
194195

195196
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
@@ -273,8 +274,9 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
273274
arity_arg = sizeof_dtypeN / sizeof_dtype1
274275
ndims_param = Symbol(name='ndims', dtype=size_t)
275276
ndims_arg = obj.ndim
276-
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
277-
dimensions=(Dimension(name='d'),), scope='rvalue')
277+
shape_param = Array(name=f'{obj.name}_shape', scope='rvalue',
278+
dtype=np.int32 if obj.is_regular else np.uint64,
279+
dimensions=(Dimension(name='d'),))
278280
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
279281

280282
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)

0 commit comments

Comments
 (0)