Skip to content

Commit cb43942

Browse files
committed
compiler: Postpone weights unevaluation
1 parent 58f1c70 commit cb43942

3 files changed

Lines changed: 6 additions & 26 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,8 @@ def __init_finalize__(self, *args, **kwargs):
746746
assert isinstance(weights, (list, tuple, np.ndarray))
747747

748748
# Normalize `weights`
749-
from devito.symbolics import pow_to_mul, unevaluate # noqa, sigh
750-
weights = tuple(unevaluate(pow_to_mul(sympy.sympify(i))) for i in weights)
749+
from devito.symbolics import pow_to_mul # noqa, sigh
750+
weights = tuple(pow_to_mul(sympy.sympify(i)) for i in weights)
751751

752752
kwargs['scope'] = kwargs.get('scope', 'stack')
753753
kwargs['initvalue'] = weights

devito/passes/iet/definitions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from devito.passes.iet.engine import iet_pass
1717
from devito.passes.iet.langbase import LangBB
1818
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
19-
SizeOf, VOID, pow_to_mul)
19+
SizeOf, VOID, pow_to_mul, unevaluate)
2020
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
2121
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
2222
DeviceRM, Eq, Symbol)
@@ -119,7 +119,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage):
119119

120120
# Create input array
121121
name = '%s_init' % obj.name
122-
initvalue = np.array([pow_to_mul(i) for i in obj.initvalue])
122+
initvalue = np.array([unevaluate(pow_to_mul(i)) for i in obj.initvalue])
123123
src = Array(name=name, dtype=obj.dtype, dimensions=obj.dimensions,
124124
space='host', scope='stack', initvalue=initvalue)
125125

tests/test_pickle.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ctypes
22
import pickle as pickle0
3-
import cloudpickle as pickle1
43

4+
import cloudpickle as pickle1
55
import pytest
66
import numpy as np
77
from sympy import Symbol
@@ -12,15 +12,14 @@
1212
PrecomputedSparseTimeFunction, SubDomain)
1313
from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext
1414
from devito.data import LEFT, OWNED
15-
from devito.finite_differences.differentiable import Weights
1615
from devito.finite_differences.tools import direct, transpose, left, right, centered
1716
from devito.mpi.halo_scheme import Halo
1817
from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject,
1918
MPIRegion)
2019
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
2120
PointerArray, Lock, PThreadArray, SharedData, Timer,
2221
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
23-
FIndexed, ComponentAccess, StencilDimension)
22+
FIndexed, ComponentAccess)
2423
from devito.types.basic import BoundSymbol, AbstractSymbol
2524
from devito.tools import EnrichedTuple
2625
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
@@ -431,25 +430,6 @@ def test_component_access(self, pickle):
431430
assert new_ca.index == 1
432431
assert new_ca.function.name == f.name
433432

434-
def test_weights_to_array(self, pickle):
435-
grid = Grid(shape=(3, 3, 3))
436-
x, y, z = grid.dimensions
437-
h_x = x.spacing
438-
439-
i = StencilDimension('i0', 0, 2)
440-
w = Weights(name='w0', dimensions=i,
441-
initvalue=[1/(h_x**2), 2/(h_x**2), 3/(h_x**2)])
442-
a = Array(name='w0', dimensions=w.dimensions, initvalue=w.initvalue,
443-
scope='stack')
444-
445-
pkl_a = pickle.dumps(a)
446-
new_a = pickle.loads(pkl_a)
447-
448-
# Weights optimizes `initvalue` by turning pows into muls. This test checks
449-
# that the optimization is correctly carried over to the pickled object
450-
# (in practice, the optimized expressions must have been frozen)
451-
assert a.initvalue == new_a.initvalue
452-
453433
def test_symbolics(self, pickle):
454434
a = Symbol('a')
455435

0 commit comments

Comments
 (0)