@@ -168,49 +168,55 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
168168 """
169169 decl = Definition (obj )
170170
171+ # Symbols/pointers
172+ size = Symbol (name = 'size' , dtype = np .uint64 , is_const = True )
173+ ffp0 = FieldFromPointer (obj ._C_field_data , obj ._C_symbol )
174+ ffp1 = FieldFromPointer (obj ._C_field_shape , obj ._C_symbol )
175+ ffp2 = FieldFromPointer (obj ._C_field_size , obj ._C_symbol )
176+ ffp3 = FieldFromPointer (obj ._C_field_nbytes , obj ._C_symbol )
177+
171178 # Allocate the Array struct
172179 memptr = VOID (Byref (obj ._C_symbol ), '**' )
173180 alignment = obj ._data_alignment
174181 nbytes = SizeOf (obj ._C_typedata )
175- allocs = [ self .langbb ['host-alloc' ](memptr , alignment , nbytes )]
182+ alloc0 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
176183
177- nbytes_param = Symbol (name = 'nbytes' , dtype = np .uint64 , is_const = True )
178- nbytes_arg = SizeOf (obj .indexed ._C_typedata )* obj .size
184+ # Allocate the shape array
185+ memptr = VOID (Byref (ffp1 ), '**' )
186+ nbytes = SizeOf (obj ._C_size_type )* obj .ndim
187+ alloc1 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
188+
189+ # Initialize the Array struct
190+ init = [* [DummyExpr (IndexedPointer (ffp1 , i ), s )
191+ for i , s in enumerate (obj .symbolic_shape )],
192+ DummyExpr (size , obj .size , init = True ),
193+ DummyExpr (ffp2 , size ),
194+ DummyExpr (ffp3 , size * SizeOf (obj .indexed ._C_typedata ))]
179195
180196 # Allocate the underlying host data
181- ffp0 = FieldFromPointer (obj ._C_field_data , obj ._C_symbol )
182197 memptr = VOID (Byref (ffp0 ), '**' )
183- allocs .append (self .langbb ['host-alloc-pin' ](memptr , alignment , nbytes_param ))
184-
185- # Initialize the Array struct
186- ffp1 = FieldFromPointer (obj ._C_field_nbytes , obj ._C_symbol )
187- init0 = DummyExpr (ffp1 , nbytes_param )
188- ffp2 = FieldFromPointer (obj ._C_field_size , obj ._C_symbol )
189- init1 = DummyExpr (ffp2 , 0 )
198+ alloc2 = self .langbb ['host-alloc-pin' ](memptr , alignment , ffp3 )
190199
200+ # Free all of the allocated data
191201 frees = [self .langbb ['host-free-pin' ](ffp0 ),
202+ self .langbb ['host-free' ](ffp1 ),
192203 self .langbb ['host-free' ](obj ._C_symbol )]
193204
194205 # Allocate the underlying device data, if required by the backend
195- alloc , free = self ._make_dmap_allocfree (obj , nbytes_param )
196-
197- # Chain together all allocs and frees
198- allocs = as_tuple (allocs ) + as_tuple (alloc )
199- frees = as_tuple (free ) + as_tuple (frees )
206+ alloc_dmap , free_dmap = self ._make_dmap_allocfree (obj , ffp3 )
200207
201208 ret = Return (obj ._C_symbol )
202209
203210 # Wrap everything in a Callable so that we can reuse the same code
204211 # for equivalent Array structs
205212 name = self .sregistry .make_name (prefix = 'alloc' )
206- body = (decl , * allocs , init0 , init1 , ret )
213+ body = (decl , alloc0 , alloc1 , * init , alloc2 , * as_tuple ( alloc_dmap ) , ret )
207214 efunc0 = make_callable (name , body , retval = obj )
208- args = list (efunc0 .parameters )
209- args [args .index (nbytes_param )] = nbytes_arg
210- alloc = Call (name , args , retobj = obj )
215+ alloc = Call (name , efunc0 .parameters , retobj = obj )
211216
212217 # Same story for the frees
213218 name = self .sregistry .make_name (prefix = 'free' )
219+ frees = as_tuple (free_dmap ) + as_tuple (frees )
214220 efunc1 = make_callable (name , frees )
215221 free = Call (name , efunc1 .parameters )
216222
0 commit comments