1919from devito .passes .iet .langbase import LangBB
2020from devito .symbolics import (
2121 Byref , DefFunction , FieldFromPointer , IndexedPointer , ListInitializer ,
22- SizeOf , VOID , pow_to_mul , unevaluate
22+ SizeOf , VOID , pow_to_mul , unevaluate , LONG , retrieve_symbols
2323)
2424from devito .tools import as_mapper , as_list , as_tuple , filter_sorted , flatten
2525from devito .types import (
@@ -90,6 +90,17 @@ def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
9090 self .rcompile = rcompile
9191 self .sregistry = sregistry
9292 self .platform = platform
93+ self .index_mode = kwargs .get ('options' , {'index-mode' : 'int32' })['index-mode' ]
94+
95+ def intm (self , nbytes ):
96+ if self .index_mode == 'int64' :
97+ try :
98+ syms = retrieve_symbols (nbytes )
99+ return nbytes .subs ({s : LONG (s ) for s in syms })
100+ except AttributeError :
101+ return LONG (nbytes )
102+ else :
103+ return nbytes
93104
94105 def _alloc_object_on_low_lat_mem (self , site , obj , storage ):
95106 """
@@ -136,7 +147,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage):
136147
137148 # Copy input array into global array
138149 name = self .sregistry .make_name (prefix = 'init_global' )
139- nbytes = SizeOf (obj ._C_typedata )* obj .size
150+ nbytes = SizeOf (obj ._C_typedata )* self . intm ( obj .size )
140151 body = [Definition (src ),
141152 self .langbb ['alloc-global-symbol' ](obj .indexed , src .indexed , nbytes )]
142153 efunc = make_callable (name , body )
@@ -159,7 +170,7 @@ def _alloc_host_array_on_high_bw_mem(self, site, obj, storage, *args):
159170
160171 memptr = VOID (Byref (obj ._C_symbol ), '**' )
161172 alignment = obj ._data_alignment
162- nbytes = SizeOf (obj ._C_typedata )* obj .size
173+ nbytes = SizeOf (obj ._C_typedata )* self . intm ( obj .size )
163174 alloc = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
164175
165176 free = self .langbb ['host-free' ](obj ._C_symbol )
@@ -358,15 +369,15 @@ def _alloc_pointed_array_on_high_bw_mem(self, site, obj, storage):
358369
359370 memptr = VOID (Byref (obj ._C_symbol ), '**' )
360371 alignment = obj ._data_alignment
361- nbytes = SizeOf (obj ._C_typedata , stars = '*' )* obj .dim .symbolic_size
372+ nbytes = SizeOf (obj ._C_typedata , stars = '*' )* self . intm ( obj .dim .symbolic_size )
362373 alloc0 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
363374
364375 free0 = self .langbb ['host-free' ](obj ._C_symbol )
365376
366377 # The pointee Array
367378 pobj = IndexedPointer (obj ._C_symbol , obj .dim )
368379 memptr = VOID (Byref (pobj ), '**' )
369- nbytes = SizeOf (obj ._C_typedata )* obj .array .size
380+ nbytes = SizeOf (obj ._C_typedata )* self . intm ( obj .array .size )
370381 alloc1 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
371382
372383 free1 = self .langbb ['host-free' ](pobj )
0 commit comments