Skip to content

Commit 8e73386

Browse files
committed
compiler: fix temp allocation with long index-mode
1 parent 598fd62 commit 8e73386

6 files changed

Lines changed: 42 additions & 14 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _print_UnaryOp(self, expr, op=None, parenthesize=False):
370370

371371
def _print_Cast(self, expr):
372372
cast = f'({self._print(expr._C_ctype)}{self._print(expr.stars)})'
373-
return self._print_UnaryOp(expr, op=cast)
373+
return self._print_UnaryOp(expr, op=cast, parenthesize=True)
374374

375375
def _print_ComponentAccess(self, expr):
376376
return f"{self._print(expr.base)}.{expr.sindex}"

devito/passes/iet/definitions.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from devito.passes.iet.langbase import LangBB
2020
from devito.symbolics import (
2121
Byref, DefFunction, FieldFromPointer, IndexedPointer, ListInitializer,
22-
SizeOf, VOID, pow_to_mul, unevaluate
22+
SizeOf, VOID, pow_to_mul, unevaluate, LONG, retrieve_symbols
2323
)
2424
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
2525
from devito.types import (
@@ -90,6 +90,17 @@ def __init__(self, rcompile=None, sregistry=None, platform=None, **kwargs):
9090
self.rcompile = rcompile
9191
self.sregistry = sregistry
9292
self.platform = platform
93+
self.index_mode = kwargs.get('options', {'index-mode': 'int32'})['index-mode']
94+
95+
def intm(self, nbytes):
96+
if self.index_mode == 'int64':
97+
try:
98+
syms = retrieve_symbols(nbytes)
99+
return nbytes.subs({s: LONG(s) for s in syms})
100+
except AttributeError:
101+
return LONG(nbytes)
102+
else:
103+
return nbytes
93104

94105
def _alloc_object_on_low_lat_mem(self, site, obj, storage):
95106
"""
@@ -136,7 +147,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage):
136147

137148
# Copy input array into global array
138149
name = self.sregistry.make_name(prefix='init_global')
139-
nbytes = SizeOf(obj._C_typedata)*obj.size
150+
nbytes = SizeOf(obj._C_typedata)*self.intm(obj.size)
140151
body = [Definition(src),
141152
self.langbb['alloc-global-symbol'](obj.indexed, src.indexed, nbytes)]
142153
efunc = make_callable(name, body)
@@ -159,7 +170,7 @@ def _alloc_host_array_on_high_bw_mem(self, site, obj, storage, *args):
159170

160171
memptr = VOID(Byref(obj._C_symbol), '**')
161172
alignment = obj._data_alignment
162-
nbytes = SizeOf(obj._C_typedata)*obj.size
173+
nbytes = SizeOf(obj._C_typedata)*self.intm(obj.size)
163174
alloc = self.langbb['host-alloc'](memptr, alignment, nbytes)
164175

165176
free = self.langbb['host-free'](obj._C_symbol)
@@ -358,15 +369,15 @@ def _alloc_pointed_array_on_high_bw_mem(self, site, obj, storage):
358369

359370
memptr = VOID(Byref(obj._C_symbol), '**')
360371
alignment = obj._data_alignment
361-
nbytes = SizeOf(obj._C_typedata, stars='*')*obj.dim.symbolic_size
372+
nbytes = SizeOf(obj._C_typedata, stars='*')*self.intm(obj.dim.symbolic_size)
362373
alloc0 = self.langbb['host-alloc'](memptr, alignment, nbytes)
363374

364375
free0 = self.langbb['host-free'](obj._C_symbol)
365376

366377
# The pointee Array
367378
pobj = IndexedPointer(obj._C_symbol, obj.dim)
368379
memptr = VOID(Byref(pobj), '**')
369-
nbytes = SizeOf(obj._C_typedata)*obj.array.size
380+
nbytes = SizeOf(obj._C_typedata)*self.intm(obj.array.size)
370381
alloc1 = self.langbb['host-alloc'](memptr, alignment, nbytes)
371382

372383
free1 = self.langbb['host-free'](pobj)

devito/symbolics/extended_dtypes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from devito.tools.dtypes_lowering import dtype_mapper
88

99
__all__ = ['cast', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa
10-
'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex']
10+
'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex',
11+
'LONG']
1112

1213

1314
limits_mapper = {
@@ -72,6 +73,7 @@ def cast(casttype, stars=None):
7273

7374
ULONG = cast(np.uint64)
7475
UINTP = cast(np.uint32, '*')
76+
LONG = cast(np.int64)
7577

7678

7779
# Standard ones, needed as class for e.g. single dispatch

tests/test_iet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ class SpecialObject2(LocalObject):
345345
assert 'bar<int,float>& obj2;' in str(iet)
346346
assert 'dummy obj3 = meh;' in str(iet)
347347
assert 'dummy obj4(1,2) = meh;' in str(iet)
348-
assert 'dummy obj5(ceil(1.0F/(float)s)) = meh;' in str(iet)
348+
assert 'dummy obj5(ceil(1.0F/(float)(s))) = meh;' in str(iet)
349349
assert 'float obj6 = meh;' in str(iet)
350350

351351

tests/test_linearize.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import scipy.sparse
44

55
from devito import (Grid, Function, TimeFunction, SparseTimeFunction, Operator, Eq,
6-
Inc, MatrixSparseTimeFunction, sin, switchconfig)
6+
Inc, MatrixSparseTimeFunction, sin, switchconfig, configuration)
77
from devito.ir import Call, Callable, DummyExpr, Expression, FindNodes, SymbolRegistry
88
from devito.passes import Graph, linearize, generate_macros
99
from devito.types import Array, Bundle, DefaultDimension
@@ -640,3 +640,18 @@ def _test_different_dtype():
640640
assert "L0(x,y) f[(x)*y_stride0 + (y)]" in str(op1)
641641

642642
_test_different_dtype()
643+
644+
645+
@pytest.mark.parametrize('order', [2, 4])
646+
def test_int64_array(order):
647+
648+
grid = Grid(shape=(4, 4))
649+
f = Function(name='f', grid=grid, space_order=order)
650+
651+
a = Array(name='a', dimensions=grid.dimensions, shape=grid.shape,
652+
halo=f.halo)
653+
654+
eqs = [Eq(f, a.indexify() + 1)]
655+
op = Operator(eqs, opt=('advanced', {'linearize': True, 'index-mode': 'int64'}))
656+
long = 'static_cast<long>' if 'CXX' in configuration['language'] else '(long)'
657+
assert f'({2*order} + {long}(y_size))*({2*order} + {long}(x_size)))' in str(op)

tests/test_symbolics.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,10 @@ class BarCast(BaseCast):
426426
_dtype = 'bar'
427427

428428
v = BarCast(s, '**')
429-
assert ccode(v) == '(bar**)s'
429+
assert ccode(v) == '(bar**)(s)'
430430

431431
# Reconstruction
432-
assert ccode(v.func(*v.args)) == '(bar**)s'
432+
assert ccode(v.func(*v.args)) == '(bar**)(s)'
433433

434434
v1 = BarCast(s, '****')
435435
assert v != v1
@@ -442,19 +442,19 @@ def test_str_cast():
442442
assert not v.stars
443443
assert v.dtype == 'foo'
444444
assert v._op == '(foo)'
445-
assert ccode(v) == '(foo)s'
445+
assert ccode(v) == '(foo)(s)'
446446

447447
v = Cast(s, 'foo*')
448448
assert v.stars == '*'
449449
assert v.dtype == 'foo'
450450
assert v._op == '(foo*)'
451-
assert ccode(v) == '(foo*)s'
451+
assert ccode(v) == '(foo*)(s)'
452452

453453
v = Cast(s, 'foo **')
454454
assert v.stars == '**'
455455
assert v.dtype == 'foo'
456456
assert v._op == '(foo**)'
457-
assert ccode(v) == '(foo**)s'
457+
assert ccode(v) == '(foo**)(s)'
458458

459459

460460
def test_findexed():

0 commit comments

Comments
 (0)