Skip to content

Commit e6806ec

Browse files
committed
WIP WIP WIP [MASSIVE OVERHAUL IN PROGRESS, STILL BREAKING]
1 parent 2e52e88 commit e6806ec

2 files changed

Lines changed: 42 additions & 33 deletions

File tree

devito/passes/iet/definitions.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -168,49 +168,55 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
168168
"""
169169
decl = Definition(obj)
170170

171+
# Symbols/pointers
172+
size = Symbol(name='size', dtype=np.uint64, is_const=True)
173+
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
174+
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
175+
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
176+
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
177+
171178
# Allocate the Array struct
172179
memptr = VOID(Byref(obj._C_symbol), '**')
173180
alignment = obj._data_alignment
174181
nbytes = SizeOf(obj._C_typedata)
175-
allocs = [self.langbb['host-alloc'](memptr, alignment, nbytes)]
182+
alloc0 = self.langbb['host-alloc'](memptr, alignment, nbytes)
176183

177-
nbytes_param = Symbol(name='nbytes', dtype=np.uint64, is_const=True)
178-
nbytes_arg = SizeOf(obj.indexed._C_typedata)*obj.size
184+
# Allocate the shape array
185+
memptr = VOID(Byref(ffp1), '**')
186+
nbytes = SizeOf(obj._C_size_type)*obj.ndim
187+
alloc1 = self.langbb['host-alloc'](memptr, alignment, nbytes)
188+
189+
# Initialize the Array struct
190+
init = [*[DummyExpr(IndexedPointer(ffp1, i), s)
191+
for i, s in enumerate(obj.symbolic_shape)],
192+
DummyExpr(size, obj.size, init=True),
193+
DummyExpr(ffp2, size),
194+
DummyExpr(ffp3, size*SizeOf(obj.indexed._C_typedata))]
179195

180196
# Allocate the underlying host data
181-
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
182197
memptr = VOID(Byref(ffp0), '**')
183-
allocs.append(self.langbb['host-alloc-pin'](memptr, alignment, nbytes_param))
184-
185-
# Initialize the Array struct
186-
ffp1 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
187-
init0 = DummyExpr(ffp1, nbytes_param)
188-
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
189-
init1 = DummyExpr(ffp2, 0)
198+
alloc2 = self.langbb['host-alloc-pin'](memptr, alignment, ffp3)
190199

200+
# Free all of the allocated data
191201
frees = [self.langbb['host-free-pin'](ffp0),
202+
self.langbb['host-free'](ffp1),
192203
self.langbb['host-free'](obj._C_symbol)]
193204

194205
# Allocate the underlying device data, if required by the backend
195-
alloc, free = self._make_dmap_allocfree(obj, nbytes_param)
196-
197-
# Chain together all allocs and frees
198-
allocs = as_tuple(allocs) + as_tuple(alloc)
199-
frees = as_tuple(free) + as_tuple(frees)
206+
alloc_dmap, free_dmap = self._make_dmap_allocfree(obj, ffp3)
200207

201208
ret = Return(obj._C_symbol)
202209

203210
# Wrap everything in a Callable so that we can reuse the same code
204211
# for equivalent Array structs
205212
name = self.sregistry.make_name(prefix='alloc')
206-
body = (decl, *allocs, init0, init1, ret)
213+
body = (decl, alloc0, alloc1, *init, alloc2, *as_tuple(alloc_dmap), ret)
207214
efunc0 = make_callable(name, body, retval=obj)
208-
args = list(efunc0.parameters)
209-
args[args.index(nbytes_param)] = nbytes_arg
210-
alloc = Call(name, args, retobj=obj)
215+
alloc = Call(name, efunc0.parameters, retobj=obj)
211216

212217
# Same story for the frees
213218
name = self.sregistry.make_name(prefix='free')
219+
frees = as_tuple(free_dmap) + as_tuple(frees)
214220
efunc1 = make_callable(name, frees)
215221
free = Call(name, efunc1.parameters)
216222

devito/types/array.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ctypes import POINTER, Structure, c_void_p, c_ulong, c_uint64
1+
from ctypes import POINTER, Structure, c_void_p, c_uint64
22
from functools import cached_property
33

44
import numpy as np
@@ -202,19 +202,27 @@ def _make_pointer(self, dim):
202202
return PointerArray(name='p%s' % self.name, dimensions=dim, array=self)
203203

204204

205-
class ArrayMapped(Array):
205+
class MappedArrayMixin:
206206

207207
_C_structname = 'array'
208208
_C_field_data = 'data'
209209
_C_field_dmap = 'dmap'
210-
_C_field_nbytes = 'nbytes'
210+
_C_field_shape = 'shape'
211211
_C_field_size = 'size'
212+
_C_field_nbytes = 'nbytes'
213+
214+
_C_size_type = c_uint64
212215

213216
_C_ctype = POINTER(type(_C_structname, (Structure,),
214217
{'_fields_': [(_C_field_data, c_restrict_void_p),
215-
(_C_field_nbytes, c_ulong),
216218
(_C_field_dmap, c_void_p),
217-
(_C_field_size, c_uint64)]}))
219+
(_C_field_shape, POINTER(_C_size_type)),
220+
(_C_field_size, _C_size_type),
221+
(_C_field_nbytes, _C_size_type)]}))
222+
223+
224+
class ArrayMapped(MappedArrayMixin, Array):
225+
pass
218226

219227

220228
class ArrayObject(ArrayBasic):
@@ -343,7 +351,7 @@ def array(self):
343351
return self._array
344352

345353

346-
class Bundle(ArrayBasic):
354+
class Bundle(MappedArrayMixin, ArrayBasic):
347355

348356
"""
349357
Tensor symbol representing an unrolled vector of AbstractFunctions.
@@ -490,17 +498,12 @@ def __getitem__(self, index):
490498
raise ValueError("Expected %d or %d indices, got %d instead"
491499
% (self.ndim, self.ndim + 1, len(index)))
492500

493-
_C_structname = ArrayMapped._C_structname
494-
_C_field_data = ArrayMapped._C_field_data
495-
_C_field_nbytes = ArrayMapped._C_field_nbytes
496-
_C_field_dmap = ArrayMapped._C_field_dmap
497-
_C_field_size = ArrayMapped._C_field_size
498-
499501
@property
500502
def _C_ctype(self):
501503
if self._mem_mapped:
502-
return ArrayMapped._C_ctype
504+
return super()._C_ctype
503505
else:
506+
#TODO DROP???
504507
return POINTER(dtype_to_ctype(self.dtype))
505508

506509

0 commit comments

Comments
 (0)