|
14 | 14 | from devito.ir import Cluster, Scope, cluster_pass |
15 | 15 | from devito.symbolics import estimate_cost, q_leaf, q_terminal |
16 | 16 | from devito.symbolics.manipulation import _uxreplace |
17 | | -from devito.tools import DAG, as_list, as_tuple, frozendict |
| 17 | +from devito.tools import DAG, as_list, as_tuple, frozendict, infer_dtype |
18 | 18 | from devito.types import Eq, Symbol, Temp |
19 | 19 |
|
20 | 20 | __all__ = ['cse'] |
@@ -78,7 +78,8 @@ def cse(cluster, sregistry=None, options=None, **kwargs): |
78 | 78 | if cluster.is_fence: |
79 | 79 | return cluster |
80 | 80 |
|
81 | | - make = lambda: CTemp(name=sregistry.make_name(), dtype=dtype) |
| 81 | + make_dtype = lambda e: np.promote_types(e.dtype, dtype).type |
| 82 | + make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e)) |
82 | 83 |
|
83 | 84 | exprs = _cse(cluster, make, min_cost=min_cost, mode=mode) |
84 | 85 |
|
@@ -118,7 +119,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'): |
118 | 119 | exprs = maybe_exprs |
119 | 120 | scope = Scope(maybe_exprs) |
120 | 121 | else: |
121 | | - exprs = [Eq(make(), e) for e in maybe_exprs] |
| 122 | + exprs = [Eq(make(e), e) for e in maybe_exprs] |
122 | 123 | scope = Scope([]) |
123 | 124 |
|
124 | 125 | # Some sub-expressions aren't really "common" -- that's the case of Dimension- |
@@ -155,7 +156,7 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'): |
155 | 156 | candidates = [c for c in candidates if c.cost == cost] |
156 | 157 |
|
157 | 158 | # Apply replacements |
158 | | - chosen = [(c, scheduled.get(c.key) or make()) for c in candidates] |
| 159 | + chosen = [(c, scheduled.get(c.key) or make(c)) for c in candidates] |
159 | 160 | exprs = _inject(exprs, chosen, scheduled) |
160 | 161 |
|
161 | 162 | # Drop useless temporaries (e.g., r0=r1) |
@@ -275,6 +276,12 @@ def __new__(cls, expr, conditionals=None, sources=()): |
275 | 276 | def expr(self): |
276 | 277 | return self[0] |
277 | 278 |
|
| 279 | + @property |
| 280 | + def dtype(self): |
| 281 | + dtypes = {getattr(e, 'dtype', None) |
| 282 | + for e in self.expr.free_symbols} |
| 283 | + return infer_dtype(dtypes - {None}) |
| 284 | + |
278 | 285 | @property |
279 | 286 | def conditionals(self): |
280 | 287 | return self[1] |
|
0 commit comments