|
4 | 4 | import pytest |
5 | 5 | import numpy as np |
6 | 6 |
|
7 | | -from sympy import Expr, Symbol |
| 7 | +from sympy import Expr, Number, Symbol |
8 | 8 | from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa |
9 | 9 | Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, |
10 | 10 | Min, Max) |
|
13 | 13 | from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa |
14 | 14 | CallFromPointer, Cast, DefFunction, FieldFromPointer, |
15 | 15 | INT, FieldFromComposite, IntDiv, Namespace, Rvalue, |
16 | | - ReservedWord, ListInitializer, uxreplace, |
| 16 | + ReservedWord, ListInitializer, uxreplace, pow_to_mul, |
17 | 17 | retrieve_derivatives, BaseCast) |
18 | 18 | from devito.tools import as_tuple |
19 | 19 | from devito.types import (Array, Bundle, FIndexed, LocalObject, Object, |
@@ -589,51 +589,67 @@ def test_solve_time(): |
589 | 589 | assert sympy.simplify(sympy.expand(sol - (-dt**2*u.dx/m + 2.0*u - u.backward))) == 0 |
590 | 590 |
|
591 | 591 |
|
592 | | -@pytest.mark.parametrize('expr,subs,expected', [ |
593 | | - ('f', '{f: g}', 'g'), |
594 | | - ('f[x, y+1]', '{f.indexed: g.indexed}', 'g[x, y+1]'), |
595 | | - ('cos(f)', '{cos: sin}', 'sin(f)'), |
596 | | - ('cos(f + sin(g))', '{cos: sin, sin: cos}', 'sin(f + cos(g))'), |
597 | | - ('FIndexed(f.indexed, x, y)', '{x: 0}', 'FIndexed(f.indexed, 0, y)'), |
598 | | -]) |
599 | | -def test_uxreplace(expr, subs, expected): |
600 | | - grid = Grid(shape=(4, 4)) |
601 | | - x, y = grid.dimensions # noqa |
| 592 | +class TestUxreplace: |
602 | 593 |
|
603 | | - f = Function(name='f', grid=grid) # noqa |
604 | | - g = Function(name='g', grid=grid) # noqa |
| 594 | + @pytest.mark.parametrize('expr,subs,expected', [ |
| 595 | + ('f', '{f: g}', 'g'), |
| 596 | + ('f[x, y+1]', '{f.indexed: g.indexed}', 'g[x, y+1]'), |
| 597 | + ('cos(f)', '{cos: sin}', 'sin(f)'), |
| 598 | + ('cos(f + sin(g))', '{cos: sin, sin: cos}', 'sin(f + cos(g))'), |
| 599 | + ('FIndexed(f.indexed, x, y)', '{x: 0}', 'FIndexed(f.indexed, 0, y)'), |
| 600 | + ]) |
| 601 | + def test_expressions(self, expr, subs, expected): |
| 602 | + grid = Grid(shape=(4, 4)) |
| 603 | + x, y = grid.dimensions # noqa |
605 | 604 |
|
606 | | - assert uxreplace(eval(expr), eval(subs)) == eval(expected) |
| 605 | + f = Function(name='f', grid=grid) # noqa |
| 606 | + g = Function(name='g', grid=grid) # noqa |
607 | 607 |
|
| 608 | + assert uxreplace(eval(expr), eval(subs)) == eval(expected) |
608 | 609 |
|
609 | | -def test_uxreplace_custom_reconstructable(): |
| 610 | + def test_custom_reconstructable(self): |
610 | 611 |
|
611 | | - class MyDefFunction(DefFunction): |
612 | | - __rargs__ = ('name', 'arguments') |
613 | | - __rkwargs__ = ('p0', 'p1', 'p2') |
| 612 | + class MyDefFunction(DefFunction): |
| 613 | + __rargs__ = ('name', 'arguments') |
| 614 | + __rkwargs__ = ('p0', 'p1', 'p2') |
614 | 615 |
|
615 | | - def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None): |
616 | | - obj = super().__new__(cls, name=name, arguments=arguments) |
617 | | - obj.p0 = p0 |
618 | | - obj.p1 = as_tuple(p1) |
619 | | - obj.p2 = p2 |
620 | | - return obj |
| 616 | + def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None): |
| 617 | + obj = super().__new__(cls, name=name, arguments=arguments) |
| 618 | + obj.p0 = p0 |
| 619 | + obj.p1 = as_tuple(p1) |
| 620 | + obj.p2 = p2 |
| 621 | + return obj |
621 | 622 |
|
622 | | - grid = Grid(shape=(4, 4)) |
| 623 | + grid = Grid(shape=(4, 4)) |
623 | 624 |
|
624 | | - f = Function(name='f', grid=grid) |
625 | | - g = Function(name='g', grid=grid) |
| 625 | + f = Function(name='f', grid=grid) |
| 626 | + g = Function(name='g', grid=grid) |
| 627 | + |
| 628 | + func = MyDefFunction(name='foo', arguments=f.indexify(), |
| 629 | + p0=f, p1=f, p2='bar') |
| 630 | + |
| 631 | + mapper = {f: g, f.indexify(): g.indexify()} |
| 632 | + func1 = uxreplace(func, mapper) |
| 633 | + |
| 634 | + assert func1.arguments == (g.indexify(),) |
| 635 | + assert func1.p0 is g |
| 636 | + assert func1.p1 == (g,) |
| 637 | + assert func1.p2 == 'bar' |
| 638 | + |
| 639 | + def test_reduce_to_number(self): |
| 640 | + grid = Grid(shape=(4, 4)) |
| 641 | + x, _ = grid.dimensions |
| 642 | + h_x = x.spacing |
626 | 643 |
|
627 | | - func = MyDefFunction(name='foo', arguments=f.indexify(), |
628 | | - p0=f, p1=f, p2='bar') |
| 644 | + # Emulate lowered coefficient |
| 645 | + w = -0.0354212/(h_x*h_x) |
| 646 | + w_lowered = pow_to_mul(w) |
629 | 647 |
|
630 | | - mapper = {f: g, f.indexify(): g.indexify()} |
631 | | - func1 = uxreplace(func, mapper) |
| 648 | + w_sub = uxreplace(w_lowered, {h_x: Number(3)}) |
632 | 649 |
|
633 | | - assert func1.arguments == (g.indexify(),) |
634 | | - assert func1.p0 is g |
635 | | - assert func1.p1 == (g,) |
636 | | - assert func1.p2 == 'bar' |
| 650 | + assert np.isclose(w_sub, -0.003935689) |
| 651 | + assert not w_sub.is_Mul |
| 652 | + assert w_sub.is_Number |
637 | 653 |
|
638 | 654 |
|
639 | 655 | def test_minmax(): |
|
0 commit comments