Skip to content

Commit 92673a7

Browse files
authored
Merge pull request #2571 from devitocodes/hotfix-weight-eval
compiler: Fix uxreplace behaviour with numeric expressions
2 parents aca2d83 + b5dc57b commit 92673a7

2 files changed

Lines changed: 58 additions & 38 deletions

File tree

devito/symbolics/manipulation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from devito.symbolics.queries import q_leaf
1515
from devito.symbolics.search import retrieve_indexed, retrieve_functions
1616
from devito.symbolics.unevaluation import (
17-
Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow
17+
Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow, UnevaluableMixin
1818
)
1919
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
2020
from devito.types.basic import Basic, Indexed
@@ -300,7 +300,11 @@ def _eval_numbers(expr, args):
300300
"""
301301
numbers, others = split(args, lambda i: i.is_Number)
302302
if len(numbers) > 1:
303-
args[:] = [expr.func(*numbers)] + others
303+
if isinstance(expr, UnevaluableMixin):
304+
cls = expr.func.__base__
305+
else:
306+
cls = expr.func
307+
args[:] = [cls(*numbers)] + others
304308

305309

306310
def flatten_args(args, op, ignore=None):

tests/test_symbolics.py

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import numpy as np
66

7-
from sympy import Expr, Symbol
7+
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
1010
Min, Max)
@@ -13,7 +13,7 @@
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
1414
CallFromPointer, Cast, DefFunction, FieldFromPointer,
1515
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
16-
ReservedWord, ListInitializer, uxreplace,
16+
ReservedWord, ListInitializer, uxreplace, pow_to_mul,
1717
retrieve_derivatives, BaseCast)
1818
from devito.tools import as_tuple
1919
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
@@ -589,51 +589,67 @@ def test_solve_time():
589589
assert sympy.simplify(sympy.expand(sol - (-dt**2*u.dx/m + 2.0*u - u.backward))) == 0
590590

591591

592-
@pytest.mark.parametrize('expr,subs,expected', [
593-
('f', '{f: g}', 'g'),
594-
('f[x, y+1]', '{f.indexed: g.indexed}', 'g[x, y+1]'),
595-
('cos(f)', '{cos: sin}', 'sin(f)'),
596-
('cos(f + sin(g))', '{cos: sin, sin: cos}', 'sin(f + cos(g))'),
597-
('FIndexed(f.indexed, x, y)', '{x: 0}', 'FIndexed(f.indexed, 0, y)'),
598-
])
599-
def test_uxreplace(expr, subs, expected):
600-
grid = Grid(shape=(4, 4))
601-
x, y = grid.dimensions # noqa
592+
class TestUxreplace:
602593

603-
f = Function(name='f', grid=grid) # noqa
604-
g = Function(name='g', grid=grid) # noqa
594+
@pytest.mark.parametrize('expr,subs,expected', [
595+
('f', '{f: g}', 'g'),
596+
('f[x, y+1]', '{f.indexed: g.indexed}', 'g[x, y+1]'),
597+
('cos(f)', '{cos: sin}', 'sin(f)'),
598+
('cos(f + sin(g))', '{cos: sin, sin: cos}', 'sin(f + cos(g))'),
599+
('FIndexed(f.indexed, x, y)', '{x: 0}', 'FIndexed(f.indexed, 0, y)'),
600+
])
601+
def test_expressions(self, expr, subs, expected):
602+
grid = Grid(shape=(4, 4))
603+
x, y = grid.dimensions # noqa
605604

606-
assert uxreplace(eval(expr), eval(subs)) == eval(expected)
605+
f = Function(name='f', grid=grid) # noqa
606+
g = Function(name='g', grid=grid) # noqa
607607

608+
assert uxreplace(eval(expr), eval(subs)) == eval(expected)
608609

609-
def test_uxreplace_custom_reconstructable():
610+
def test_custom_reconstructable(self):
610611

611-
class MyDefFunction(DefFunction):
612-
__rargs__ = ('name', 'arguments')
613-
__rkwargs__ = ('p0', 'p1', 'p2')
612+
class MyDefFunction(DefFunction):
613+
__rargs__ = ('name', 'arguments')
614+
__rkwargs__ = ('p0', 'p1', 'p2')
614615

615-
def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None):
616-
obj = super().__new__(cls, name=name, arguments=arguments)
617-
obj.p0 = p0
618-
obj.p1 = as_tuple(p1)
619-
obj.p2 = p2
620-
return obj
616+
def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None):
617+
obj = super().__new__(cls, name=name, arguments=arguments)
618+
obj.p0 = p0
619+
obj.p1 = as_tuple(p1)
620+
obj.p2 = p2
621+
return obj
621622

622-
grid = Grid(shape=(4, 4))
623+
grid = Grid(shape=(4, 4))
623624

624-
f = Function(name='f', grid=grid)
625-
g = Function(name='g', grid=grid)
625+
f = Function(name='f', grid=grid)
626+
g = Function(name='g', grid=grid)
627+
628+
func = MyDefFunction(name='foo', arguments=f.indexify(),
629+
p0=f, p1=f, p2='bar')
630+
631+
mapper = {f: g, f.indexify(): g.indexify()}
632+
func1 = uxreplace(func, mapper)
633+
634+
assert func1.arguments == (g.indexify(),)
635+
assert func1.p0 is g
636+
assert func1.p1 == (g,)
637+
assert func1.p2 == 'bar'
638+
639+
def test_reduce_to_number(self):
640+
grid = Grid(shape=(4, 4))
641+
x, _ = grid.dimensions
642+
h_x = x.spacing
626643

627-
func = MyDefFunction(name='foo', arguments=f.indexify(),
628-
p0=f, p1=f, p2='bar')
644+
# Emulate lowered coefficient
645+
w = -0.0354212/(h_x*h_x)
646+
w_lowered = pow_to_mul(w)
629647

630-
mapper = {f: g, f.indexify(): g.indexify()}
631-
func1 = uxreplace(func, mapper)
648+
w_sub = uxreplace(w_lowered, {h_x: Number(3)})
632649

633-
assert func1.arguments == (g.indexify(),)
634-
assert func1.p0 is g
635-
assert func1.p1 == (g,)
636-
assert func1.p2 == 'bar'
650+
assert np.isclose(w_sub, -0.003935689)
651+
assert not w_sub.is_Mul
652+
assert w_sub.is_Number
637653

638654

639655
def test_minmax():

0 commit comments

Comments
 (0)