Skip to content

Commit 17ca3b9

Browse files
authored
Merge pull request #2567 from devitocodes/fix-reconstruction
compiler: Fix reconstruction/pickling of low level objects
2 parents ad99102 + cb43942 commit 17ca3b9

6 files changed

Lines changed: 65 additions & 33 deletions

File tree

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import sympy
2-
31
from devito.ir import cluster_pass
4-
from devito.symbolics import reuse_if_untouched, q_leaf
5-
from devito.symbolics.unevaluation import Add, Mul, Pow
2+
from devito.symbolics import unevaluate as _unevaluate
63

74
__all__ = ['unevaluate']
85

@@ -12,22 +9,3 @@ def unevaluate(cluster):
129
exprs = [_unevaluate(e) for e in cluster.exprs]
1310

1411
return cluster.rebuild(exprs=exprs)
15-
16-
17-
mapper = {
18-
sympy.Add: Add,
19-
sympy.Mul: Mul,
20-
sympy.Pow: Pow
21-
}
22-
23-
24-
def _unevaluate(expr):
25-
if q_leaf(expr):
26-
return expr
27-
28-
args = [_unevaluate(a) for a in expr.args]
29-
30-
try:
31-
return mapper[expr.func](*args)
32-
except KeyError:
33-
return reuse_if_untouched(expr, args)

devito/passes/iet/definitions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from devito.passes.iet.engine import iet_pass
1717
from devito.passes.iet.langbase import LangBB
1818
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
19-
SizeOf, VOID, pow_to_mul)
19+
SizeOf, VOID, pow_to_mul, unevaluate)
2020
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
2121
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
2222
DeviceRM, Eq, Symbol)
@@ -119,7 +119,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage):
119119

120120
# Create input array
121121
name = '%s_init' % obj.name
122-
initvalue = np.array([pow_to_mul(i) for i in obj.initvalue])
122+
initvalue = np.array([unevaluate(pow_to_mul(i)) for i in obj.initvalue])
123123
src = Array(name=name, dtype=obj.dtype, dimensions=obj.dimensions,
124124
space='host', scope='stack', initvalue=initvalue)
125125

devito/symbolics/manipulation.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from devito.symbolics.extended_sympy import DefFunction, rfunc
1414
from devito.symbolics.queries import q_leaf
1515
from devito.symbolics.search import retrieve_indexed, retrieve_functions
16-
from devito.symbolics.unevaluation import Mul as UMul
16+
from devito.symbolics.unevaluation import (
17+
Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow
18+
)
1719
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
1820
from devito.types.basic import Basic, Indexed
1921
from devito.types.array import ComponentAccess
@@ -22,7 +24,7 @@
2224

2325
__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
2426
'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite',
25-
'reuse_if_untouched', 'evalrel', 'flatten_args']
27+
'reuse_if_untouched', 'evalrel', 'flatten_args', 'unevaluate']
2628

2729

2830
def uxreplace(expr, rule):
@@ -338,7 +340,7 @@ def pow_to_mul(expr):
338340
# but at least we traverse the base looking for other Pows
339341
return expr.func(pow_to_mul(base), exp, evaluate=False)
340342
elif exp > 0:
341-
return UMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
343+
return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
342344
elif exp < 0:
343345
# Reciprocal powers become inverse of the negative power
344346
# for example Pow(expr, -2) becomes Pow(expr * expr, -1)
@@ -502,3 +504,18 @@ def evalrel(func=min, input=None, assumptions=None):
502504
except TypeError:
503505
pass
504506
return rfunc(func, *input)
507+
508+
509+
uneval_mapper = {Add: UnevalAdd, Mul: UnevalMul, Pow: UnevalPow}
510+
511+
512+
def unevaluate(expr):
513+
if q_leaf(expr):
514+
return expr
515+
516+
args = [unevaluate(a) for a in expr.args]
517+
518+
try:
519+
return uneval_mapper[expr.func](*args)
520+
except KeyError:
521+
return reuse_if_untouched(expr, args)

devito/types/array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from sympy import Expr, cacheit
66

7-
from devito.tools import (Reconstructable, as_tuple, c_restrict_void_p,
7+
from devito.tools import (Pickable, as_tuple, c_restrict_void_p,
88
dtype_to_ctype, dtypes_vector_mapper, is_integer)
99
from devito.types.basic import AbstractFunction, LocalType
1010
from devito.types.utils import CtypesFactory, DimensionTuple
@@ -518,10 +518,11 @@ def handles(self):
518518
return self.components
519519

520520

521-
class ComponentAccess(Expr, Reconstructable):
521+
class ComponentAccess(Expr, Pickable):
522522

523523
_component_names = ('x', 'y', 'z', 'w')
524524

525+
__rargs__ = ('arg',)
525526
__rkwargs__ = ('index',)
526527

527528
def __new__(cls, arg, index=0, **kwargs):
@@ -543,7 +544,7 @@ def __str__(self):
543544

544545
__repr__ = __str__
545546

546-
func = Reconstructable._rebuild
547+
func = Pickable._rebuild
547548

548549
def _sympystr(self, printer):
549550
return str(self)
@@ -552,6 +553,10 @@ def _sympystr(self, printer):
552553
def base(self):
553554
return self.args[0]
554555

556+
@property
557+
def arg(self):
558+
return self.base
559+
555560
@property
556561
def index(self):
557562
return self._index

tests/test_pickle.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ctypes
22
import pickle as pickle0
3-
import cloudpickle as pickle1
43

4+
import cloudpickle as pickle1
55
import pytest
66
import numpy as np
77
from sympy import Symbol
@@ -19,7 +19,7 @@
1919
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
2020
PointerArray, Lock, PThreadArray, SharedData, Timer,
2121
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
22-
FIndexed)
22+
FIndexed, ComponentAccess)
2323
from devito.types.basic import BoundSymbol, AbstractSymbol
2424
from devito.tools import EnrichedTuple
2525
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
@@ -416,6 +416,20 @@ def test_findexed(self, pickle):
416416
assert new_fi.indices == (x+1, y, z-2)
417417
assert new_fi.strides_map == fi.strides_map
418418

419+
def test_component_access(self, pickle):
420+
grid = Grid(shape=(3, 3, 3))
421+
x, y, z = grid.dimensions
422+
423+
f = Function(name='f', grid=grid)
424+
425+
ca = ComponentAccess(f.indexify(), 1)
426+
427+
pkl_ca = pickle.dumps(ca)
428+
new_ca = pickle.loads(pkl_ca)
429+
430+
assert new_ca.index == 1
431+
assert new_ca.function.name == f.name
432+
419433
def test_symbolics(self, pickle):
420434
a = Symbol('a')
421435

tests/test_symbolics.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,24 @@ def test_findexed():
483483
assert new_fi.strides_map == strides_map
484484

485485

486+
def test_component_access():
487+
grid = Grid(shape=(3, 3, 3))
488+
x, y, z = grid.dimensions
489+
490+
f = Function(name='f', grid=grid)
491+
492+
cf0 = ComponentAccess(f.indexify(), 0)
493+
cf1 = ComponentAccess(f.indexify(), 1)
494+
495+
assert ccode(cf0) == 'f[x][y][z].x'
496+
assert ccode(cf1) == 'f[x][y][z].y'
497+
498+
# Reconstruction
499+
cf2 = cf1.func(*cf1.args)
500+
assert cf2.index == cf1.index
501+
assert cf2 == cf1
502+
503+
486504
def test_canonical_ordering_of_weights():
487505
grid = Grid(shape=(3, 3, 3))
488506
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)