Skip to content

Commit 552fd7f

Browse files
committed
Isolate multistage-related changes only
1 parent 6139faa commit 552fd7f

6 files changed

Lines changed: 1137 additions & 7 deletions

File tree

devito/ir/equations/algorithms.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered,
77
frozendict)
88
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
9-
ConditionalDimension)
9+
ConditionalDimension, MultiStage)
1010
from devito.types.array import Array
1111
from devito.types.basic import AbstractFunction
1212
from devito.types.dimension import MultiSubDimension, Thickness
1313
from devito.data.allocators import DataReference
1414
from devito.logger import warning
1515

16-
__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims']
16+
17+
__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims']
1718

1819

1920
def dimension_sort(expr):
@@ -95,6 +96,39 @@ def handle_indexed(indexed):
9596
return ordering
9697

9798

99+
def lower_multistage(expressions, **kwargs):
100+
"""
101+
Separating the multi-stage time-integrator scheme in stages:
102+
* If the object is MultiStage, it creates the stages of the method.
103+
"""
104+
return _lower_multistage(expressions, **kwargs)
105+
106+
107+
@singledispatch
108+
def _lower_multistage(expr, **kwargs):
109+
"""
110+
Default handler for expressions that are not MultiStage.
111+
Simply return them in a list.
112+
"""
113+
return [expr]
114+
115+
116+
@_lower_multistage.register(MultiStage)
117+
def _(expr, **kwargs):
118+
"""
119+
Specialized handler for MultiStage expressions.
120+
"""
121+
return expr._evaluate(**kwargs)
122+
123+
124+
@_lower_multistage.register(Iterable)
125+
def _(exprs, **kwargs):
126+
"""
127+
Handle iterables of expressions.
128+
"""
129+
return sum([_lower_multistage(expr, **kwargs) for expr in exprs], [])
130+
131+
98132
def lower_exprs(expressions, subs=None, **kwargs):
99133
"""
100134
Lowering an expression consists of the following passes:

devito/operations/solve.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from devito.finite_differences.derivative import Derivative
88
from devito.tools import as_tuple
99

10+
from devito.types.multistage import resolve_method
11+
1012
__all__ = ['solve', 'linsolve']
1113

1214

@@ -15,7 +17,7 @@ class SolveError(Exception):
1517
pass
1618

1719

18-
def solve(eq, target, **kwargs):
20+
def solve(eq, target, method = None, eq_num = 0, **kwargs):
1921
"""
2022
Algebraically rearrange an Eq w.r.t. a given symbol.
2123
@@ -56,9 +58,12 @@ def solve(eq, target, **kwargs):
5658

5759
# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
5860
if len(sols) > 1:
59-
return target.new_from_mat(sols)
61+
sols_temp = target.new_from_mat(sols)
6062
else:
61-
return sols[0]
63+
sols_temp = sols[0]
64+
65+
method = kwargs.get("method", None)
66+
return sols_temp if method is None else resolve_method(method)(target, sols_temp)
6267

6368

6469
def linsolve(expr, target, **kwargs):

devito/operator/operator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
InvalidOperator)
1818
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
1919
switch_log_level)
20-
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
20+
from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims
2121
from devito.ir.clusters import ClusterGroup, clusterize
2222
from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction,
2323
FindSymbols, MetaCall, derive_parameters, iet_build)
@@ -40,7 +40,6 @@
4040
disk_layer)
4141
from devito.types.dimension import Thickness
4242

43-
4443
__all__ = ['Operator']
4544

4645

@@ -337,6 +336,8 @@ def _lower_exprs(cls, expressions, **kwargs):
337336
* Apply substitution rules;
338337
* Shift indices for domain alignment.
339338
"""
339+
expressions = lower_multistage(expressions, **kwargs)
340+
340341
expand = kwargs['options'].get('expand', True)
341342

342343
# Specialization is performed on unevaluated expressions

devito/types/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,6 @@
2222
from .relational import * # noqa
2323
from .sparse import * # noqa
2424
from .tensor import * # noqa
25+
26+
from .multistage import * # noqa
27+
from .multistage_new import * # noqa

0 commit comments

Comments
 (0)