Skip to content

Commit 6b904d9

Browse files
committed
compiler: Fix Weights reconstruction
1 parent ad99102 commit 6b904d9

4 files changed

Lines changed: 44 additions & 29 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,8 @@ def __init_finalize__(self, *args, **kwargs):
746746
assert isinstance(weights, (list, tuple, np.ndarray))
747747

748748
# Normalize `weights`
749-
from devito.symbolics import pow_to_mul # noqa, sigh
750-
weights = tuple(pow_to_mul(sympy.sympify(i)) for i in weights)
749+
from devito.symbolics import pow_to_mul, unevaluate # noqa, sigh
750+
weights = tuple(unevaluate(pow_to_mul(sympy.sympify(i))) for i in weights)
751751

752752
kwargs['scope'] = kwargs.get('scope', 'stack')
753753
kwargs['initvalue'] = weights
Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import sympy
2-
31
from devito.ir import cluster_pass
4-
from devito.symbolics import reuse_if_untouched, q_leaf
5-
from devito.symbolics.unevaluation import Add, Mul, Pow
2+
from devito.symbolics import unevaluate as _unevaluate
63

74
__all__ = ['unevaluate']
85

@@ -12,22 +9,3 @@ def unevaluate(cluster):
129
exprs = [_unevaluate(e) for e in cluster.exprs]
1310

1411
return cluster.rebuild(exprs=exprs)
15-
16-
17-
mapper = {
18-
sympy.Add: Add,
19-
sympy.Mul: Mul,
20-
sympy.Pow: Pow
21-
}
22-
23-
24-
def _unevaluate(expr):
25-
if q_leaf(expr):
26-
return expr
27-
28-
args = [_unevaluate(a) for a in expr.args]
29-
30-
try:
31-
return mapper[expr.func](*args)
32-
except KeyError:
33-
return reuse_if_untouched(expr, args)

devito/symbolics/manipulation.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from devito.symbolics.extended_sympy import DefFunction, rfunc
1414
from devito.symbolics.queries import q_leaf
1515
from devito.symbolics.search import retrieve_indexed, retrieve_functions
16-
from devito.symbolics.unevaluation import Mul as UMul
16+
from devito.symbolics.unevaluation import (
17+
Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow
18+
)
1719
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
1820
from devito.types.basic import Basic, Indexed
1921
from devito.types.array import ComponentAccess
@@ -22,7 +24,7 @@
2224

2325
__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
2426
'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite',
25-
'reuse_if_untouched', 'evalrel', 'flatten_args']
27+
'reuse_if_untouched', 'evalrel', 'flatten_args', 'unevaluate']
2628

2729

2830
def uxreplace(expr, rule):
@@ -338,7 +340,7 @@ def pow_to_mul(expr):
338340
# but at least we traverse the base looking for other Pows
339341
return expr.func(pow_to_mul(base), exp, evaluate=False)
340342
elif exp > 0:
341-
return UMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
343+
return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
342344
elif exp < 0:
343345
# Reciprocal powers become inverse of the negative power
344346
# for example Pow(expr, -2) becomes Pow(expr * expr, -1)
@@ -502,3 +504,18 @@ def evalrel(func=min, input=None, assumptions=None):
502504
except TypeError:
503505
pass
504506
return rfunc(func, *input)
507+
508+
509+
uneval_mapper = {Add: UnevalAdd, Mul: UnevalMul, Pow: UnevalPow}
510+
511+
512+
def unevaluate(expr):
513+
if q_leaf(expr):
514+
return expr
515+
516+
args = [unevaluate(a) for a in expr.args]
517+
518+
try:
519+
return uneval_mapper[expr.func](*args)
520+
except KeyError:
521+
return reuse_if_untouched(expr, args)

tests/test_pickle.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
PrecomputedSparseTimeFunction, SubDomain)
1313
from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext
1414
from devito.data import LEFT, OWNED
15+
from devito.finite_differences.differentiable import Weights
1516
from devito.finite_differences.tools import direct, transpose, left, right, centered
1617
from devito.mpi.halo_scheme import Halo
1718
from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject,
1819
MPIRegion)
1920
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
2021
PointerArray, Lock, PThreadArray, SharedData, Timer,
2122
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
22-
FIndexed)
23+
FIndexed, StencilDimension)
2324
from devito.types.basic import BoundSymbol, AbstractSymbol
2425
from devito.tools import EnrichedTuple
2526
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
@@ -416,6 +417,25 @@ def test_findexed(self, pickle):
416417
assert new_fi.indices == (x+1, y, z-2)
417418
assert new_fi.strides_map == fi.strides_map
418419

420+
def test_weights_to_array(self, pickle):
421+
grid = Grid(shape=(3, 3, 3))
422+
x, y, z = grid.dimensions
423+
h_x = x.spacing
424+
425+
i = StencilDimension('i0', 0, 2)
426+
w = Weights(name='w0', dimensions=i,
427+
initvalue=[1/(h_x**2), 2/(h_x**2), 3/(h_x**2)])
428+
a = Array(name='w0', dimensions=w.dimensions, initvalue=w.initvalue,
429+
scope='stack')
430+
431+
pkl_a = pickle.dumps(a)
432+
new_a = pickle.loads(pkl_a)
433+
434+
# Weights optimizes `initvalue` by turning pows into muls. This test checks
435+
# that the optimization is correctly carried over to the pickled object
436+
# (in practice, the optimized expressions must have been frozen)
437+
assert a.initvalue == new_a.initvalue
438+
419439
def test_symbolics(self, pickle):
420440
a = Symbol('a')
421441

0 commit comments

Comments
 (0)