Skip to content

Commit 9ee60cc

Browse files
EdCauntmloubout
authored andcommitted
dsl: Add Re and Im operators for taking real and imaginary parts of an expression
1 parent a4d4bd4 commit 9ee60cc

5 files changed

Lines changed: 179 additions & 4 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.types.basic import AbstractFunction
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
25-
'Weights']
25+
'Weights', 'Re', 'Im']
2626

2727

2828
class Differentiable(sympy.Expr, Evaluable):
@@ -644,6 +644,39 @@ def __str__(self):
644644
__repr__ = __str__
645645

646646

647+
class ComplexPart(Differentiable, sympy.core.function.Application):
648+
"""Abstract class for `Re` or `Im` of an expression"""
649+
650+
def __new__(cls, *args, **kwargs):
651+
if len(args) != 1:
652+
raise ValueError(f"{cls.__name__} is constructed with exactly one arg;"
653+
f" {len(args)} were supplied.")
654+
655+
# Diffify any Add, Mul, etc which might be in the expression
656+
new_args = (diffify(args[0]),)
657+
658+
if not np.issubdtype(new_args[0].dtype, np.complexfloating):
659+
raise ValueError(f"{cls.__name__} requires a complex dtype,"
660+
f" not {new_args[0].dtype.__name__}.")
661+
662+
return super().__new__(cls, *new_args, **kwargs)
663+
664+
def __str__(self):
665+
return f"{self.__class__.__name__}({self.args[0]})"
666+
667+
__repr__ = __str__
668+
669+
670+
class Re(ComplexPart):
671+
"""Get the real part of an expression"""
672+
pass
673+
674+
675+
class Im(ComplexPart):
676+
"""Get the imaginary part of an expression"""
677+
pass
678+
679+
647680
class IndexSum(sympy.Expr, Evaluable):
648681

649682
"""

devito/ir/equations/equation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace,
99
Stencil, detect_io, detect_accesses)
1010
from devito.symbolics import IntDiv, limits_mapper, uxreplace
11-
from devito.tools import Pickable, Tag, frozendict
11+
from devito.tools import Pickable, Tag, frozendict, infer_dtype
1212
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
1313

1414
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax',
@@ -48,7 +48,12 @@ def directions(self):
4848

4949
@property
5050
def dtype(self):
51-
return self.lhs.dtype
51+
try:
52+
rhs_dtype = self.rhs.dtype
53+
except AttributeError:
54+
rhs_dtype = None
55+
56+
return infer_dtype({self.lhs.dtype, rhs_dtype} - {None})
5257

5358
@property
5459
def state(self):

devito/passes/iet/languages/C.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,13 @@ def _print_ListInitializer(self, expr):
6767
return f'({tstr}[]){li}'
6868
else:
6969
return li
70+
71+
def _print_Re(self, expr):
72+
"""Print an Re as an access into the second entry of a float array."""
73+
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
74+
f'({self._print(expr.args[0])})')
75+
76+
def _print_Im(self, expr):
77+
"""Print an Im as an access into the second entry of a float array."""
78+
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
79+
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
106106

107107
def _print_ImaginaryUnit(self, expr):
108108
return f'1i{self.prec_literal(expr).lower()}'
109+
# return '1i'
110+
111+
def _print_Re(self, expr):
112+
return f'{self._ns}real({self._print(expr.args[0])})'
113+
114+
def _print_Im(self, expr):
115+
return f'{self._ns}imag({self._print(expr.args[0])})'
109116

110117
def _print_Cast(self, expr):
111118
# The CXX recommended way to cast is to use static_cast

tests/test_symbolics.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max, SubDomain)
10+
Min, Max, Re, Im, SubDomain)
1111
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -924,3 +924,123 @@ def test_customdtype_complex():
924924

925925
assert not f.is_imaginary
926926
assert f.is_real
927+
928+
929+
class TestComplexParts:
930+
# TODO: Add a cxx switchconfig
931+
def setup_basic(self, dtype):
932+
grid = Grid(shape=(5,), extent=(4.,))
933+
f = Function(name='f', grid=grid, dtype=dtype)
934+
f.data_with_halo[:] = np.arange(7) + 1j*np.arange(7, 14)[::-1]
935+
936+
f_real = Function(name='f_real', grid=grid)
937+
f_imag = Function(name='f_imag', grid=grid)
938+
return f, f_real, f_imag
939+
940+
def run_operator(self, eqs, cxx):
941+
if cxx:
942+
with switchconfig(language='CXX'):
943+
Operator(eqs)()
944+
else:
945+
Operator(eqs)()
946+
947+
@pytest.mark.parametrize('cxx', [False, True])
948+
def test_printing(self, cxx):
949+
f, f_real, f_imag = self.setup_basic(np.complex64)
950+
951+
eq_re = Eq(f_real, Re(f))
952+
eq_im = Eq(f_imag, Im(f))
953+
954+
if cxx:
955+
with switchconfig(language='CXX'):
956+
op = Operator([eq_re, eq_im])
957+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
958+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
959+
960+
else:
961+
op = Operator([eq_re, eq_im])
962+
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
963+
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
964+
965+
@pytest.mark.parametrize('cxx', [False, True])
966+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
967+
def test_trivial(self, cxx, dtype):
968+
f, f_real, f_imag = self.setup_basic(dtype)
969+
970+
eq_re = Eq(f_real, Re(f+1.))
971+
eq_im = Eq(f_imag, Im(f+1.))
972+
973+
self.run_operator([eq_re, eq_im], cxx)
974+
975+
rcheck = np.array([2., 3., 4., 5., 6.])
976+
icheck = np.array([12., 11., 10., 9., 8.])
977+
assert np.all(np.isclose(f_real.data, rcheck))
978+
assert np.all(np.isclose(f_imag.data, icheck))
979+
980+
@pytest.mark.parametrize('cxx', [False, True])
981+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
982+
def test_trivial_imag(self, cxx, dtype):
983+
f, f_real, f_imag = self.setup_basic(dtype)
984+
985+
eq_re = Eq(f_real, Re(f+1j))
986+
eq_im = Eq(f_imag, Im(f+1j))
987+
988+
self.run_operator([eq_re, eq_im], cxx)
989+
990+
rcheck = np.array([1., 2., 3., 4., 5.])
991+
icheck = np.array([13., 12., 11., 10., 9.])
992+
assert np.all(np.isclose(f_real.data, rcheck))
993+
assert np.all(np.isclose(f_imag.data, icheck))
994+
995+
@pytest.mark.parametrize('cxx', [False, True])
996+
def test_deriv(self, cxx):
997+
f, f_real, f_imag = self.setup_basic(np.complex64)
998+
999+
eq_re = Eq(f_real, Re(f.dx))
1000+
eq_im = Eq(f_imag, Im(f.dx))
1001+
1002+
self.run_operator([eq_re, eq_im], cxx)
1003+
1004+
assert np.all(np.isclose(f_real.data, 1.))
1005+
assert np.all(np.isclose(f_imag.data, -1.))
1006+
1007+
@pytest.mark.parametrize('cxx', [False, True])
1008+
def test_outer_deriv(self, cxx):
1009+
f, f_real, f_imag = self.setup_basic(np.complex64)
1010+
1011+
eq_re = Eq(f_real, Re(f).dx)
1012+
eq_im = Eq(f_imag, Im(f).dx)
1013+
1014+
self.run_operator([eq_re, eq_im], cxx)
1015+
1016+
assert np.all(np.isclose(f_real.data, 1.))
1017+
assert np.all(np.isclose(f_imag.data, -1.))
1018+
1019+
@pytest.mark.parametrize('cxx', [False, True])
1020+
def test_mul(self, cxx):
1021+
grid = Grid(shape=(5,))
1022+
1023+
f = Function(name='f', grid=grid, dtype=np.complex64)
1024+
g = Function(name='g', grid=grid)
1025+
h = Function(name='h', grid=grid, dtype=np.complex64)
1026+
f.data[:] = 1 + 1j
1027+
g.data[:] = 2
1028+
h.data[:] = 2j
1029+
1030+
fg_re = Function(name='fg_re', grid=grid)
1031+
fg_im = Function(name='fg_im', grid=grid)
1032+
fh_re = Function(name='fh_re', grid=grid)
1033+
fh_im = Function(name='fh_im', grid=grid)
1034+
1035+
eq_fg_re = Eq(fg_re, Re(f*g))
1036+
eq_fg_im = Eq(fg_im, Im(f*g))
1037+
eq_fh_re = Eq(fh_re, Re(f*h))
1038+
eq_fh_im = Eq(fh_im, Im(f*h))
1039+
1040+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
1041+
1042+
assert np.all(np.isclose(fg_re.data, 2.))
1043+
assert np.all(np.isclose(fg_im.data, 2.))
1044+
1045+
assert np.all(np.isclose(fh_re.data, -2.))
1046+
assert np.all(np.isclose(fh_im.data, 2.))

0 commit comments

Comments
 (0)