Skip to content

Commit cb2ed5a

Browse files
committed
dsl: Add Re and Im operators for taking real and imaginary parts of an expression
1 parent d6980b7 commit cb2ed5a

5 files changed

Lines changed: 179 additions & 4 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.types.basic import AbstractFunction
2323

2424
__all__ = ['Differentiable', 'DiffDerivative', 'IndexDerivative', 'EvalDerivative',
25-
'Weights']
25+
'Weights', 'Re', 'Im']
2626

2727

2828
class Differentiable(sympy.Expr, Evaluable):
@@ -644,6 +644,39 @@ def __str__(self):
644644
__repr__ = __str__
645645

646646

647+
class ComplexPart(Differentiable, sympy.core.function.Application):
648+
"""Abstract class for `Re` or `Im` of an expression"""
649+
650+
def __new__(cls, *args, **kwargs):
651+
if len(args) != 1:
652+
raise ValueError(f"{cls.__name__} is constructed with exactly one arg;"
653+
f" {len(args)} were supplied.")
654+
655+
# Diffify any Add, Mul, etc which might be in the expression
656+
new_args = (diffify(args[0]),)
657+
658+
if not np.issubdtype(new_args[0].dtype, np.complexfloating):
659+
raise ValueError(f"{cls.__name__} requires a complex dtype,"
660+
f" not {new_args[0].dtype.__name__}.")
661+
662+
return super().__new__(cls, *new_args, **kwargs)
663+
664+
def __str__(self):
665+
return f"{self.__class__.__name__}({self.args[0]})"
666+
667+
__repr__ = __str__
668+
669+
670+
class Re(ComplexPart):
671+
"""Get the real part of an expression"""
672+
pass
673+
674+
675+
class Im(ComplexPart):
676+
"""Get the imaginary part of an expression"""
677+
pass
678+
679+
647680
class IndexSum(sympy.Expr, Evaluable):
648681

649682
"""

devito/ir/equations/equation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace,
99
Stencil, detect_io, detect_accesses)
1010
from devito.symbolics import IntDiv, limits_mapper, uxreplace
11-
from devito.tools import Pickable, Tag, frozendict
11+
from devito.tools import Pickable, Tag, frozendict, infer_dtype
1212
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
1313

1414
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax',
@@ -48,7 +48,12 @@ def directions(self):
4848

4949
@property
5050
def dtype(self):
51-
return self.lhs.dtype
51+
try:
52+
rhs_dtype = self.rhs.dtype
53+
except AttributeError:
54+
rhs_dtype = None
55+
56+
return infer_dtype({self.lhs.dtype, rhs_dtype} - {None})
5257

5358
@property
5459
def state(self):

devito/passes/iet/languages/C.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,13 @@ def _print_ListInitializer(self, expr):
6767
return f'({tstr}[]){li}'
6868
else:
6969
return li
70+
71+
def _print_Re(self, expr):
72+
"""Print an Re as an access into the second entry of a float array."""
73+
return (f'{self.func_prefix(expr)}real{self.func_literal(expr).lower()}'
74+
f'({self._print(expr.args[0])})')
75+
76+
def _print_Im(self, expr):
77+
"""Print an Im as an access into the second entry of a float array."""
78+
return (f'{self.func_prefix(expr)}imag{self.func_literal(expr).lower()}'
79+
f'({self._print(expr.args[0])})')

devito/passes/iet/languages/CXX.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ class CXXPrinter(BasePrinter, CXX11CodePrinter):
106106

107107
def _print_ImaginaryUnit(self, expr):
108108
return f'1i{self.prec_literal(expr).lower()}'
109+
# return '1i'
110+
111+
def _print_Re(self, expr):
112+
return f'{self._ns}real({self._print(expr.args[0])})'
113+
114+
def _print_Im(self, expr):
115+
return f'{self._ns}imag({self._print(expr.args[0])})'
109116

110117
def _print_Cast(self, expr):
111118
# The CXX recommended way to cast is to use static_cast

tests/test_symbolics.py

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max, SubDomain)
10+
Min, Max, Re, Im, SubDomain)
1111
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
@@ -912,3 +912,123 @@ def test_print_div():
912912
b = SizeOf(np.int64)
913913
cstr = ccode(a / b)
914914
assert cstr == 'sizeof(int)/sizeof(long)'
915+
916+
917+
class TestComplexParts:
918+
# TODO: Add a cxx switchconfig
919+
def setup_basic(self, dtype):
920+
grid = Grid(shape=(5,), extent=(4.,))
921+
f = Function(name='f', grid=grid, dtype=dtype)
922+
f.data_with_halo[:] = np.arange(7) + 1j*np.arange(7, 14)[::-1]
923+
924+
f_real = Function(name='f_real', grid=grid)
925+
f_imag = Function(name='f_imag', grid=grid)
926+
return f, f_real, f_imag
927+
928+
def run_operator(self, eqs, cxx):
929+
if cxx:
930+
with switchconfig(language='CXX'):
931+
Operator(eqs)()
932+
else:
933+
Operator(eqs)()
934+
935+
@pytest.mark.parametrize('cxx', [False, True])
936+
def test_printing(self, cxx):
937+
f, f_real, f_imag = self.setup_basic(np.complex64)
938+
939+
eq_re = Eq(f_real, Re(f))
940+
eq_im = Eq(f_imag, Im(f))
941+
942+
if cxx:
943+
with switchconfig(language='CXX'):
944+
op = Operator([eq_re, eq_im])
945+
assert "f_real[x + 1] = std::real(f[x + 1])" in str(op.ccode)
946+
assert "f_imag[x + 1] = std::imag(f[x + 1])" in str(op.ccode)
947+
948+
else:
949+
op = Operator([eq_re, eq_im])
950+
assert "f_real[x + 1] = crealf(f[x + 1])" in str(op.ccode)
951+
assert "f_imag[x + 1] = cimagf(f[x + 1])" in str(op.ccode)
952+
953+
@pytest.mark.parametrize('cxx', [False, True])
954+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
955+
def test_trivial(self, cxx, dtype):
956+
f, f_real, f_imag = self.setup_basic(dtype)
957+
958+
eq_re = Eq(f_real, Re(f+1.))
959+
eq_im = Eq(f_imag, Im(f+1.))
960+
961+
self.run_operator([eq_re, eq_im], cxx)
962+
963+
rcheck = np.array([2., 3., 4., 5., 6.])
964+
icheck = np.array([12., 11., 10., 9., 8.])
965+
assert np.all(np.isclose(f_real.data, rcheck))
966+
assert np.all(np.isclose(f_imag.data, icheck))
967+
968+
@pytest.mark.parametrize('cxx', [False, True])
969+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
970+
def test_trivial_imag(self, cxx, dtype):
971+
f, f_real, f_imag = self.setup_basic(dtype)
972+
973+
eq_re = Eq(f_real, Re(f+1j))
974+
eq_im = Eq(f_imag, Im(f+1j))
975+
976+
self.run_operator([eq_re, eq_im], cxx)
977+
978+
rcheck = np.array([1., 2., 3., 4., 5.])
979+
icheck = np.array([13., 12., 11., 10., 9.])
980+
assert np.all(np.isclose(f_real.data, rcheck))
981+
assert np.all(np.isclose(f_imag.data, icheck))
982+
983+
@pytest.mark.parametrize('cxx', [False, True])
984+
def test_deriv(self, cxx):
985+
f, f_real, f_imag = self.setup_basic(np.complex64)
986+
987+
eq_re = Eq(f_real, Re(f.dx))
988+
eq_im = Eq(f_imag, Im(f.dx))
989+
990+
self.run_operator([eq_re, eq_im], cxx)
991+
992+
assert np.all(np.isclose(f_real.data, 1.))
993+
assert np.all(np.isclose(f_imag.data, -1.))
994+
995+
@pytest.mark.parametrize('cxx', [False, True])
996+
def test_outer_deriv(self, cxx):
997+
f, f_real, f_imag = self.setup_basic(np.complex64)
998+
999+
eq_re = Eq(f_real, Re(f).dx)
1000+
eq_im = Eq(f_imag, Im(f).dx)
1001+
1002+
self.run_operator([eq_re, eq_im], cxx)
1003+
1004+
assert np.all(np.isclose(f_real.data, 1.))
1005+
assert np.all(np.isclose(f_imag.data, -1.))
1006+
1007+
@pytest.mark.parametrize('cxx', [False, True])
1008+
def test_mul(self, cxx):
1009+
grid = Grid(shape=(5,))
1010+
1011+
f = Function(name='f', grid=grid, dtype=np.complex64)
1012+
g = Function(name='g', grid=grid)
1013+
h = Function(name='h', grid=grid, dtype=np.complex64)
1014+
f.data[:] = 1 + 1j
1015+
g.data[:] = 2
1016+
h.data[:] = 2j
1017+
1018+
fg_re = Function(name='fg_re', grid=grid)
1019+
fg_im = Function(name='fg_im', grid=grid)
1020+
fh_re = Function(name='fh_re', grid=grid)
1021+
fh_im = Function(name='fh_im', grid=grid)
1022+
1023+
eq_fg_re = Eq(fg_re, Re(f*g))
1024+
eq_fg_im = Eq(fg_im, Im(f*g))
1025+
eq_fh_re = Eq(fh_re, Re(f*h))
1026+
eq_fh_im = Eq(fh_im, Im(f*h))
1027+
1028+
self.run_operator([eq_fg_re, eq_fg_im, eq_fh_re, eq_fh_im], cxx)
1029+
1030+
assert np.all(np.isclose(fg_re.data, 2.))
1031+
assert np.all(np.isclose(fg_im.data, 2.))
1032+
1033+
assert np.all(np.isclose(fh_re.data, -2.))
1034+
assert np.all(np.isclose(fh_im.data, 2.))

0 commit comments

Comments
 (0)