Skip to content

Commit bde894a

Browse files
committed
dsl: Add Re and Im operators for taking real and imaginary parts of an expression
1 parent 34cdc53 commit bde894a

5 files changed

Lines changed: 172 additions & 5 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.types.basic import AbstractFunction
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
25-
'Weights']
25+
'Weights', 'Re', 'Im']
2626

2727

2828
class Differentiable(sympy.Expr, Evaluable):
@@ -645,6 +645,39 @@ def __str__(self):
645645
__repr__ = __str__
646646

647647

648+
class ComplexPart(Differentiable, sympy.core.function.Application):
649+
"""Abstract class for `Re` or `Im` of an expression"""
650+
651+
def __new__(cls, *args, **kwargs):
652+
if len(args) != 1:
653+
raise ValueError(f"{cls.__name__} is constructed with exactly one arg;"
654+
f" {len(args)} were supplied.")
655+
656+
# Diffify any Add, Mul, etc which might be in the expression
657+
new_args = (diffify(args[0]),)
658+
659+
if not np.issubdtype(new_args[0].dtype, np.complexfloating):
660+
raise ValueError(f"{cls.__name__} requires a complex dtype,"
661+
f" not {new_args[0].dtype}.")
662+
663+
return super().__new__(cls, *new_args, **kwargs)
664+
665+
def __str__(self):
666+
return f"{self.__class__.__name__}({self.args[0]})"
667+
668+
__repr__ = __str__
669+
670+
671+
class Re(ComplexPart):
672+
"""Get the real part of an expression"""
673+
pass
674+
675+
676+
class Im(ComplexPart):
677+
"""Get the imaginary part of an expression"""
678+
pass
679+
680+
648681
class IndexSum(sympy.Expr, Evaluable):
649682

650683
"""

devito/ir/cgen/printer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,14 @@ def _print_SafeInv(self, expr):
217217
val = self._print(expr.val)
218218
return f'SAFEINV({val}, {base})'
219219

220+
def _print_Re(self, expr):
221+
"""Print an Re as an access into the second entry of a float array."""
222+
return f"Re({self._print(expr.args[0])})"
223+
224+
def _print_Im(self, expr):
225+
"""Print an Im as an access into the second entry of a float array."""
226+
return f"Im({self._print(expr.args[0])})"
227+
220228
def _print_Mod(self, expr):
221229
"""Print a Mod as a C-like %-based operation."""
222230
args = [f'({self._print(a)})' for a in expr.args]

devito/ir/equations/equation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace,
99
Stencil, detect_io, detect_accesses)
1010
from devito.symbolics import IntDiv, limits_mapper, uxreplace
11-
from devito.tools import Pickable, Tag, frozendict
11+
from devito.tools import Pickable, Tag, frozendict, infer_dtype
1212
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
1313

1414
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax',
@@ -48,7 +48,12 @@ def directions(self):
4848

4949
@property
5050
def dtype(self):
51-
return self.lhs.dtype
51+
try:
52+
rhs_dtype = self.rhs.dtype
53+
except AttributeError:
54+
rhs_dtype = None
55+
56+
return infer_dtype({self.lhs.dtype, rhs_dtype} - {None})
5257

5358
@property
5459
def state(self):

devito/passes/iet/misc.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sympy
66

77
from devito.finite_differences import Max, Min
8-
from devito.finite_differences.differentiable import SafeInv
8+
from devito.finite_differences.differentiable import SafeInv, Re, Im
99
from devito.logger import warning
1010
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
1111
FindApplications, FindNodes, FindSymbols, Transformer,
@@ -245,6 +245,22 @@ def _(expr, langbb):
245245
f'(((a) < {eps}F || ({b}) < {eps}F) ? (0.0F) : ((1.0F) / (a)))'),), {}
246246

247247

248+
@_lower_macro_math.register(Re)
249+
def _(expr, langbb):
250+
return (('Re(x)',
251+
'(_Generic((x), '
252+
'float _Complex : (float *) &(x), '
253+
'double _Complex : (double *) &(x))[0])'),), {}
254+
255+
256+
@_lower_macro_math.register(Im)
257+
def _(expr, langbb):
258+
return (('Im(x)',
259+
'(_Generic((x), '
260+
'float _Complex : (float *) &(x), '
261+
'double _Complex : (double *) &(x))[1])'),), {}
262+
263+
248264
@iet_pass
249265
def minimize_symbols(iet):
250266
"""

tests/test_symbolics.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max)
10+
Min, Max, Re, Im)
1111
from devito.finite_differences.differentiable import SafeInv, Weights
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -818,3 +818,108 @@ def test_assumptions(self, op, expr, assumptions, expected):
818818
assumptions = eval(assumptions)
819819
expected = eval(expected)
820820
assert evalrel(op, eqn, assumptions) == expected
821+
822+
823+
class TestComplexParts:
824+
def setup_basic(self, dtype):
825+
grid = Grid(shape=(5,), extent=(4.,))
826+
f = Function(name='f', grid=grid, dtype=dtype)
827+
f.data_with_halo[:] = np.arange(7) + 1j*np.arange(7, 14)[::-1]
828+
829+
f_real = Function(name='f_real', grid=grid)
830+
f_imag = Function(name='f_imag', grid=grid)
831+
return f, f_real, f_imag
832+
833+
def test_printing(self):
834+
f, f_real, f_imag = self.setup_basic(np.complex64)
835+
836+
eq_re = Eq(f_real, Re(f))
837+
eq_im = Eq(f_imag, Im(f))
838+
839+
op = Operator([eq_re, eq_im])
840+
841+
assert ("#define Im(x) (_Generic((x), float _Complex : (float *) &(x), "
842+
"double _Complex : (double *) &(x))[1])") in str(op.ccode)
843+
assert ("#define Re(x) (_Generic((x), float _Complex : (float *) &(x), "
844+
"double _Complex : (double *) &(x))[0])") in str(op.ccode)
845+
846+
assert "f_real[x + 1] = Re(f[x + 1])" in str(op.ccode)
847+
assert "f_imag[x + 1] = Im(f[x + 1])" in str(op.ccode)
848+
849+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
850+
def test_trivial(self, dtype):
851+
f, f_real, f_imag = self.setup_basic(dtype)
852+
853+
eq_re = Eq(f_real, Re(f+1))
854+
eq_im = Eq(f_imag, Im(f+1))
855+
856+
Operator([eq_re, eq_im])()
857+
858+
rcheck = np.array([2., 3., 4., 5., 6.])
859+
icheck = np.array([12., 11., 10., 9., 8.])
860+
assert np.all(np.isclose(f_real.data, rcheck))
861+
assert np.all(np.isclose(f_imag.data, icheck))
862+
863+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
864+
def test_trivial_imag(self, dtype):
865+
f, f_real, f_imag = self.setup_basic(dtype)
866+
867+
eq_re = Eq(f_real, Re(f+1j))
868+
eq_im = Eq(f_imag, Im(f+1j))
869+
870+
Operator([eq_re, eq_im])()
871+
872+
rcheck = np.array([1., 2., 3., 4., 5.])
873+
icheck = np.array([13., 12., 11., 10., 9.])
874+
assert np.all(np.isclose(f_real.data, rcheck))
875+
assert np.all(np.isclose(f_imag.data, icheck))
876+
877+
def test_deriv(self):
878+
f, f_real, f_imag = self.setup_basic(np.complex64)
879+
880+
eq_re = Eq(f_real, Re(f.dx))
881+
eq_im = Eq(f_imag, Im(f.dx))
882+
883+
Operator([eq_re, eq_im])()
884+
885+
assert np.all(np.isclose(f_real.data, 1.))
886+
assert np.all(np.isclose(f_imag.data, -1.))
887+
888+
def test_outer_deriv(self):
889+
f, f_real, f_imag = self.setup_basic(np.complex64)
890+
891+
eq_re = Eq(f_real, Re(f).dx)
892+
eq_im = Eq(f_imag, Im(f).dx)
893+
894+
Operator([eq_re, eq_im])()
895+
896+
assert np.all(np.isclose(f_real.data, 1.))
897+
assert np.all(np.isclose(f_imag.data, -1.))
898+
899+
def test_mul(self):
900+
grid = Grid(shape=(5,))
901+
902+
f = Function(name='f', grid=grid, dtype=np.complex64)
903+
g = Function(name='g', grid=grid)
904+
h = Function(name='h', grid=grid, dtype=np.complex64)
905+
f.data[:] = 1 + 1j
906+
g.data[:] = 2
907+
h.data[:] = 2j
908+
909+
fg_re = Function(name='fg_re', grid=grid)
910+
fg_im = Function(name='fg_im', grid=grid)
911+
fh_re = Function(name='fh_re', grid=grid)
912+
fh_im = Function(name='fh_im', grid=grid)
913+
914+
eq_fg_re = Eq(fg_re, Re(f*g))
915+
eq_fg_im = Eq(fg_im, Im(f*g))
916+
eq_fh_re = Eq(fh_re, Re(f*h))
917+
eq_fh_im = Eq(fh_im, Im(f*h))
918+
919+
Operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im])()
920+
921+
assert np.all(np.isclose(fg_re.data, 2.))
922+
assert np.all(np.isclose(fg_im.data, 2.))
923+
924+
assert np.all(np.isclose(fh_re.data, -2.))
925+
assert np.all(np.isclose(fh_im.data, 2.))

0 commit comments

Comments
 (0)