Skip to content

Commit b5dc57b

Browse files
committed
compiler: Fix uxreplace's eval of subbed numeric exprs
1 parent 8ba5259 commit b5dc57b

2 files changed

Lines changed: 23 additions & 4 deletions

File tree

devito/symbolics/manipulation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from devito.symbolics.queries import q_leaf
1515
from devito.symbolics.search import retrieve_indexed, retrieve_functions
1616
from devito.symbolics.unevaluation import (
17-
Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow
17+
Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow, UnevaluableMixin
1818
)
1919
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
2020
from devito.types.basic import Basic, Indexed
@@ -300,7 +300,11 @@ def _eval_numbers(expr, args):
300300
"""
301301
numbers, others = split(args, lambda i: i.is_Number)
302302
if len(numbers) > 1:
303-
args[:] = [expr.func(*numbers)] + others
303+
if isinstance(expr, UnevaluableMixin):
304+
cls = expr.func.__base__
305+
else:
306+
cls = expr.func
307+
args[:] = [cls(*numbers)] + others
304308

305309

306310
def flatten_args(args, op, ignore=None):

tests/test_symbolics.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import numpy as np
66

7-
from sympy import Expr, Symbol
7+
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
1010
Min, Max)
@@ -13,7 +13,7 @@
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
1414
CallFromPointer, Cast, DefFunction, FieldFromPointer,
1515
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
16-
ReservedWord, ListInitializer, uxreplace,
16+
ReservedWord, ListInitializer, uxreplace, pow_to_mul,
1717
retrieve_derivatives, BaseCast)
1818
from devito.tools import as_tuple
1919
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
@@ -636,6 +636,21 @@ def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None):
636636
assert func1.p1 == (g,)
637637
assert func1.p2 == 'bar'
638638

639+
def test_reduce_to_number(self):
640+
grid = Grid(shape=(4, 4))
641+
x, _ = grid.dimensions
642+
h_x = x.spacing
643+
644+
# Emulate lowered coefficient
645+
w = -0.0354212/(h_x*h_x)
646+
w_lowered = pow_to_mul(w)
647+
648+
w_sub = uxreplace(w_lowered, {h_x: Number(3)})
649+
650+
assert np.isclose(w_sub, -0.003935689)
651+
assert not w_sub.is_Mul
652+
assert w_sub.is_Number
653+
639654

640655
def test_minmax():
641656
grid = Grid(shape=(5, 5))

0 commit comments

Comments
 (0)