Skip to content

Commit b0e35d7

Browse files
committed
compiler: imporve handling of dtype in cse
1 parent 6dd766f commit b0e35d7

5 files changed

Lines changed: 33 additions & 24 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -654,27 +654,29 @@ def __new__(cls, *args, **kwargs):
654654
raise ValueError(f"{cls.__name__} is constructed with exactly one arg;"
655655
f" {len(args)} were supplied.")
656656

657-
# Diffify any Add, Mul, etc which might be in the expression
658-
new_args = (diffify(args[0]),)
659-
660-
if not np.issubdtype(new_args[0].dtype, np.complexfloating):
661-
raise ValueError(f"{cls.__name__} requires a complex dtype,"
662-
f" not {new_args[0].dtype.__name__}.")
663-
664-
return super().__new__(cls, *new_args, **kwargs)
657+
return super().__new__(cls, *args, **kwargs)
665658

666659
def __str__(self):
667660
return f"{self.__class__.__name__}({self.args[0]})"
668661

669662
__repr__ = __str__
670663

671664

672-
class Real(ComplexPart):
665+
class RealComplexPart(ComplexPart):
666+
667+
@cached_property
668+
def dtype(self):
669+
dtypes = {getattr(e, 'dtype', None) for e in self.free_symbols}
670+
dtype = infer_dtype(dtypes - {None})
671+
return dtype(0).real.__class__
672+
673+
674+
class Real(RealComplexPart):
673675
"""Get the real part of an expression"""
674676
_name = 'real'
675677

676678

677-
class Imag(ComplexPart):
679+
class Imag(RealComplexPart):
678680
"""Get the imaginary part of an expression"""
679681
_name = 'imag'
680682

devito/ir/clusters/cluster.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,13 @@ def dtype(self):
304304
"""
305305
dtypes = set()
306306
for i in self.exprs:
307-
try:
308-
if np.issubdtype(i.dtype, np.generic):
309-
dtypes.add(i.dtype)
310-
except TypeError:
311-
# E.g. `i.dtype` is a ctypes pointer, which has no dtype equivalent
312-
pass
307+
# try:
308+
if np.issubdtype(i.dtype, np.generic):
309+
dtypes.add(i.dtype)
310+
# except TypeError:
311+
# print(i, type(i), i.dtype, np.issubdtype(i.dtype, np.generic))
312+
# # E.g. `i.dtype` is a ctypes pointer, which has no dtype equivalent
313+
# pass
313314

314315
return infer_dtype(dtypes)
315316

devito/passes/clusters/aliases.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,6 @@ def lower_schedule(schedule, meta, sregistry, ftemps, min_dtype):
888888
else:
889889
# Degenerate case: scalar expression
890890
assert writeto.size == 0
891-
892891
dtype = sympy_dtype(pivot, base=meta.dtype, smin=min_dtype)
893892
obj = Temp(name=name, dtype=dtype)
894893
expression = Eq(obj, uxreplace(pivot, subs))
@@ -1354,7 +1353,7 @@ def cost(self):
13541353
# Not just the sum for the individual items' cost! There might be
13551354
# redundancies, which we factor out here...
13561355
counter = generator()
1357-
make = lambda: Symbol(name='dummy%d' % counter(), dtype=np.float32)
1356+
make = lambda _: Symbol(name='dummy%d' % counter(), dtype=np.float32)
13581357

13591358
tot = 0
13601359
for v in as_mapper(self, lambda i: i.ispace).values():

devito/passes/clusters/cse.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from devito.ir import Cluster, Scope, cluster_pass
1515
from devito.symbolics import estimate_cost, q_leaf, q_terminal
1616
from devito.symbolics.manipulation import _uxreplace
17-
from devito.tools import DAG, as_list, as_tuple, frozendict
17+
from devito.tools import DAG, as_list, as_tuple, frozendict, infer_dtype
1818
from devito.types import Eq, Symbol, Temp
1919

2020
__all__ = ['cse']
@@ -78,7 +78,8 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
7878
if cluster.is_fence:
7979
return cluster
8080

81-
make = lambda: CTemp(name=sregistry.make_name(), dtype=dtype)
81+
make_dtype = lambda e: np.promote_types(e.dtype, dtype).type
82+
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
8283

8384
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
8485

@@ -118,7 +119,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'):
118119
exprs = maybe_exprs
119120
scope = Scope(maybe_exprs)
120121
else:
121-
exprs = [Eq(make(), e) for e in maybe_exprs]
122+
exprs = [Eq(make(e), e) for e in maybe_exprs]
122123
scope = Scope([])
123124

124125
# Some sub-expressions aren't really "common" -- that's the case of Dimension-
@@ -155,7 +156,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'):
155156
candidates = [c for c in candidates if c.cost == cost]
156157

157158
# Apply replacements
158-
chosen = [(c, scheduled.get(c.key) or make()) for c in candidates]
159+
chosen = [(c, scheduled.get(c.key) or make(c)) for c in candidates]
159160
exprs = _inject(exprs, chosen, scheduled)
160161

161162
# Drop useless temporaries (e.g., r0=r1)
@@ -275,6 +276,12 @@ def __new__(cls, expr, conditionals=None, sources=()):
275276
def expr(self):
276277
return self[0]
277278

279+
@property
280+
def dtype(self):
281+
dtypes = {getattr(e, 'dtype', None)
282+
for e in self.expr.free_symbols}
283+
return infer_dtype(dtypes - {None})
284+
278285
@property
279286
def conditionals(self):
280287
return self[1]

tests/test_cse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_default_algo(exprs, expected, min_cost):
109109
exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate)))
110110

111111
counter = generator()
112-
make = lambda: CTemp(name='r%d' % counter()).indexify()
112+
make = lambda _: CTemp(name='r%d' % counter()).indexify()
113113
processed = _cse(exprs, make, min_cost)
114114

115115
assert len(processed) == len(expected)
@@ -241,7 +241,7 @@ def test_advanced_algo(exprs, expected):
241241
exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate)))
242242

243243
counter = generator()
244-
make = lambda: CTemp(name='r%d' % counter(), dtype=np.float32).indexify()
244+
make = lambda _: CTemp(name='r%d' % counter(), dtype=np.float32).indexify()
245245
processed = _cse(exprs, make, mode='advanced')
246246

247247
assert len(processed) == len(expected)

0 commit comments

Comments
 (0)