|
6 | 6 | from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered, |
7 | 7 | frozendict) |
8 | 8 | from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension, |
9 | | - ConditionalDimension) |
| 9 | + ConditionalDimension, MultiStage) |
10 | 10 | from devito.types.array import Array |
11 | 11 | from devito.types.basic import AbstractFunction |
12 | 12 | from devito.types.dimension import MultiSubDimension, Thickness |
13 | 13 | from devito.data.allocators import DataReference |
14 | 14 | from devito.logger import warning |
15 | 15 |
|
16 | | -__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims'] |
| 16 | + |
| 17 | +__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims'] |
17 | 18 |
|
18 | 19 |
|
19 | 20 | def dimension_sort(expr): |
@@ -95,6 +96,39 @@ def handle_indexed(indexed): |
95 | 96 | return ordering |
96 | 97 |
|
97 | 98 |
|
| 99 | +def lower_multistage(expressions, **kwargs): |
| 100 | + """ |
| 101 | + Separating the multi-stage time-integrator scheme in stages: |
| 102 | + * If the object is MultiStage, it creates the stages of the method. |
| 103 | + """ |
| 104 | + return _lower_multistage(expressions, **kwargs) |
| 105 | + |
| 106 | + |
| 107 | +@singledispatch |
| 108 | +def _lower_multistage(expr, **kwargs): |
| 109 | + """ |
| 110 | + Default handler for expressions that are not MultiStage. |
| 111 | + Simply return them in a list. |
| 112 | + """ |
| 113 | + return [expr] |
| 114 | + |
| 115 | + |
| 116 | +@_lower_multistage.register(MultiStage) |
| 117 | +def _(expr, **kwargs): |
| 118 | + """ |
| 119 | + Specialized handler for MultiStage expressions. |
| 120 | + """ |
| 121 | + return expr._evaluate(**kwargs) |
| 122 | + |
| 123 | + |
| 124 | +@_lower_multistage.register(Iterable) |
| 125 | +def _(exprs, **kwargs): |
| 126 | + """ |
| 127 | + Handle iterables of expressions. |
| 128 | + """ |
| 129 | + return sum([_lower_multistage(expr, **kwargs) for expr in exprs], []) |
| 130 | + |
| 131 | + |
98 | 132 | def lower_exprs(expressions, subs=None, **kwargs): |
99 | 133 | """ |
100 | 134 | Lowering an expression consists of the following passes: |
|
0 commit comments