@@ -229,35 +229,50 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
229229 """
230230 decl = Definition (obj )
231231
232+ arity_param = Symbol (name = 'arity' , dtype = size_t )
233+ arity_arg = SizeOf (obj .indexed ._C_typedata )
234+ ffp1 = FieldFromPointer (obj ._C_field_shape , obj ._C_symbol )
235+ ffp2 = FieldFromPointer (obj ._C_field_size , obj ._C_symbol )
236+ ffp3 = FieldFromPointer (obj ._C_field_nbytes , obj ._C_symbol )
237+
232238 # Allocate the Bundle struct
233239 memptr = VOID (Byref (obj ._C_symbol ), '**' )
234240 alignment = obj ._data_alignment
235241 nbytes = SizeOf (obj ._C_typedata )
236- alloc = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
242+ alloc0 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
237243
238- nbytes_param = Symbol (name = 'nbytes' , dtype = np .uint64 , is_const = True )
239- nbytes_arg = SizeOf (obj .indexed ._C_typedata )* obj .size
244+ # Allocate the shape array
245+ memptr = VOID (Byref (ffp1 ), '**' )
246+ nbytes = SizeOf (obj ._C_size_type )* obj .ndim
247+ alloc1 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
240248
241249 # Initialize the Bundle struct
242- ffp1 = FieldFromPointer ( obj . _C_field_nbytes , obj . _C_symbol )
243- init0 = DummyExpr ( ffp1 , nbytes_param )
244- ffp2 = FieldFromPointer ( obj . _C_field_size , obj ._C_symbol )
245- init1 = DummyExpr (ffp2 , 0 )
250+ init = [ * [ DummyExpr ( IndexedPointer ( ffp1 , i ), s )
251+ for i , s in enumerate ( obj . c0 . symbolic_shape )],
252+ DummyExpr ( ffp2 , obj .size ),
253+ DummyExpr (ffp3 , ffp2 * arity_param )]
246254
247- free = self .langbb ['host-free' ](obj ._C_symbol )
255+ # Free all of the allocated data
256+ frees = [self .langbb ['host-free' ](ffp1 ),
257+ self .langbb ['host-free' ](obj ._C_symbol )]
248258
249259 ret = Return (obj ._C_symbol )
250260
251261 # Wrap everything in a Callable so that we can reuse the same code
252262 # for equivalent Bundle structs
253263 name = self .sregistry .make_name (prefix = 'alloc' )
254- body = (decl , alloc , init0 , init1 , ret )
264+ body = (decl , alloc0 , alloc1 , * init , ret )
255265 efunc0 = make_callable (name , body , retval = obj )
256266 args = list (efunc0 .parameters )
257- args [args .index (nbytes_param )] = nbytes_arg
267+ args [args .index (arity_param )] = arity_arg
258268 alloc = Call (name , args , retobj = obj )
259269
260- storage .update (obj , site , allocs = alloc , frees = free , efuncs = efunc0 )
270+ # Same story for the frees
271+ name = self .sregistry .make_name (prefix = 'free' )
272+ efunc1 = make_callable (name , frees )
273+ free = Call (name , efunc1 .parameters )
274+
275+ storage .update (obj , site , allocs = alloc , frees = free , efuncs = (efunc0 , efunc1 ))
261276
262277 def _alloc_object_array_on_low_lat_mem (self , site , obj , storage ):
263278 """
0 commit comments