|
2 | 2 |
|
3 | 3 | import sympy |
4 | 4 | import pytest |
| 5 | +import numpy as np |
5 | 6 |
|
6 | 7 | from devito import Function, Grid, Differentiable, NODE |
7 | 8 | from devito.finite_differences.differentiable import (Add, Mul, Pow, diffify, |
@@ -96,28 +97,51 @@ def sp_diff(a, b): |
96 | 97 |
|
97 | 98 |
|
98 | 99 | @pytest.mark.parametrize('ndim', [1, 2, 3]) |
99 | | -def test_avg_mode(ndim): |
| 100 | +@pytest.mark.parametrize('io', [None, 2, 4]) |
| 101 | +def test_avg_mode(ndim, io): |
100 | 102 | grid = Grid([11]*ndim) |
101 | | - v = Function(name='v', grid=grid, staggered=grid.dimensions) |
102 | | - a0 = Function(name="a0", grid=grid) |
103 | | - a = Function(name="a", grid=grid, parameter=True) |
104 | | - b = Function(name="b", grid=grid, parameter=True, avg_mode='harmonic') |
| 103 | + v = Function(name='v', grid=grid, staggered=grid.dimensions, space_order=4) |
| 104 | + kw = {'space_order': 4} |
| 105 | + if io is not None: |
| 106 | + kw['interp_order'] = io |
| 107 | + else: |
| 108 | + io = 2 # Default value |
| 109 | + |
| 110 | + with pytest.raises(ValueError): |
| 111 | + # interp_order > space_order |
| 112 | + Function(name="a", grid=grid, parameter=True, interp_order=8, space_order=4) |
| 113 | + with pytest.raises(ValueError): |
| 114 | + # interp_order < 1 |
| 115 | + Function(name="a", grid=grid, parameter=True, interp_order=0, space_order=4) |
| 116 | + with pytest.raises(TypeError): |
| 117 | + # interp_order not int |
| 118 | + Function(name="a", grid=grid, parameter=True, interp_order=2.5, space_order=4) |
| 119 | + |
| 120 | + a0 = Function(name="a0", grid=grid, **kw) |
| 121 | + a = Function(name="a", grid=grid, parameter=True, **kw) |
| 122 | + b = Function(name="b", grid=grid, parameter=True, avg_mode='harmonic', **kw) |
105 | 123 |
|
106 | 124 | a0_avg = a0._eval_at(v) |
107 | | - a_avg = a._eval_at(v).evaluate |
108 | | - b_avg = b._eval_at(v).evaluate |
| 125 | + a_avg = a._eval_at(v).evaluate.simplify() |
| 126 | + b_avg = b._eval_at(v).evaluate.simplify() |
109 | 127 |
|
110 | 128 | assert a0_avg == a0 |
111 | 129 |
|
112 | 130 | # Indices around the point at the center of a cell |
113 | | - all_shift = tuple(product(*[[0, 1] for _ in range(ndim)])) |
| 131 | + idx = list(range(-io//2 + 1, io//2 + 1)) |
| 132 | + all_shift = tuple(product(*[idx for _ in range(ndim)])) |
| 133 | + coeffs = {2: [0.5, 0.5], 4: [-1/16, 9/16, 9/16, -1/16]}[io] |
| 134 | + vars = ['i', 'j', 'k'][:ndim] |
| 135 | + rule = ','.join(vars) + '->' + ''.join(vars) |
| 136 | + ndcoeffs = np.einsum(rule, *([coeffs]*ndim)) |
114 | 137 | args = [{d: d + i * d.spacing for d, i in zip(grid.dimensions, s)} for s in all_shift] |
115 | 138 |
|
116 | 139 | # Default is arithmetic average |
117 | | - assert sympy.simplify(a_avg - 0.5**ndim * sum(a.subs(arg) for arg in args)) == 0 |
| 140 | + expected = sum(c * a.subs(arg) for c, arg in zip(ndcoeffs.flatten(), args)) |
| 141 | + assert sympy.simplify(a_avg - expected) == 0 |
118 | 142 |
|
119 | 143 | # Harmonic average, h(a[.5]) = 1/(.5/a[0] + .5/a[1]) |
120 | | - expected = 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args)) |
121 | | - assert sympy.simplify(1/b_avg.args[0] - expected) == 0 |
| 144 | + expected = (sum(c / b.subs(arg) for c, arg in zip(ndcoeffs.flatten(), args))) |
| 145 | + assert sympy.simplify(b_avg.args[0] - expected) == 0 |
122 | 146 | assert isinstance(b_avg, SafeInv) |
123 | 147 | assert b_avg.base == b |
0 commit comments