-
Notifications
You must be signed in to change notification settings - Fork 255
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
1d830b8
7f087b3
214d882
d6c4d4a
78f8a0b
1c9d517
11db48b
83dfb04
d47a106
1f93a45
eea3a52
11d1429
4637ac2
ac1da7e
e9b3533
dc3dd77
1fd4a02
a0c45c1
93c6e3f
ef8d1ac
fa5acac
143d0c2
5f67b91
552fd7f
e9d2000
f7c9ea3
cf1003c
1fd480b
a875224
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,9 @@ | |
| from devito.data.allocators import DataReference | ||
| from devito.logger import warning | ||
|
|
||
| __all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims'] | ||
| from devito.types.multistage import MultiStage | ||
|
|
||
| __all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims'] | ||
|
|
||
|
|
||
| def dimension_sort(expr): | ||
|
|
@@ -95,6 +97,34 @@ def handle_indexed(indexed): | |
| return ordering | ||
|
|
||
|
|
||
| def lower_multistage(expressions): | ||
| """ | ||
| Separating the multi-stage time-integrator scheme in stages: | ||
| * If the object is MultiStage, it creates the stages of the method. | ||
| """ | ||
| lowered = [] | ||
| for i, eq in enumerate(as_tuple(expressions)): | ||
| lowered.extend(_lower_multistage(eq, i)) | ||
| return lowered | ||
|
|
||
|
|
||
| @singledispatch | ||
| def _lower_multistage(expr, index): | ||
| """ | ||
| Default handler for expressions that are not MultiStage. | ||
| Simply return them in a list. | ||
| """ | ||
| return [expr] | ||
|
|
||
|
|
||
| @_lower_multistage.register | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would personally tweak this for consistency with other uses of
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think now it's done |
||
| def _(expr: MultiStage, index): | ||
| """ | ||
| Specialized handler for MultiStage expressions. | ||
| """ | ||
| return expr.method(expr.eq.rhs, expr.eq.lhs)._evaluate(eq_num=index) | ||
|
|
||
|
|
||
| def lower_exprs(expressions, subs=None, **kwargs): | ||
| """ | ||
| Lowering an expression consists of the following passes: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
| from devito.finite_differences.derivative import Derivative | ||
| from devito.tools import as_tuple | ||
|
|
||
| from devito.types.multistage import MultiStage | ||
|
|
||
| __all__ = ['solve', 'linsolve'] | ||
|
|
||
|
|
||
|
|
@@ -15,7 +17,7 @@ class SolveError(Exception): | |
| pass | ||
|
|
||
|
|
||
| def solve(eq, target, **kwargs): | ||
| def solve(eq, target, method = None, eq_num = 0, **kwargs): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Kwargs should not have spaces around them. Furthermore, can
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see |
||
| """ | ||
| Algebraically rearrange an Eq w.r.t. a given symbol. | ||
|
|
||
|
|
@@ -56,9 +58,15 @@ def solve(eq, target, **kwargs): | |
|
|
||
| # We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions | ||
| if len(sols) > 1: | ||
| return target.new_from_mat(sols) | ||
| sols_temp=target.new_from_mat(sols) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should have whitespace around operator. Same below
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| else: | ||
| sols_temp=sols[0] | ||
|
|
||
| if method is not None: | ||
| method_cls = MultiStage._resolve_method(method) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To me, this implies that method_cls = eq._resolve_method() # Possibly make this a property?
if method_cls is None:
return sols_temp
return method_cls(sols_temp, target)._evaluate(eq_num=eq_num)or even just return eq._resolve_method(sols_temp, target)._evaluate(eq_num=eq_num)where As a side note, why is the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, |
||
| return method_cls(sols_temp, target)._evaluate(eq_num=eq_num) | ||
| else: | ||
| return sols[0] | ||
| return sols_temp | ||
|
|
||
|
|
||
| def linsolve(expr, target, **kwargs): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| InvalidOperator) | ||
| from devito.logger import (debug, info, perf, warning, is_log_enabled_for, | ||
| switch_log_level) | ||
| from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims | ||
| from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims | ||
| from devito.ir.clusters import ClusterGroup, clusterize | ||
| from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, | ||
| MetaCall, derive_parameters, iet_build) | ||
|
|
@@ -36,7 +36,6 @@ | |
| disk_layer) | ||
| from devito.types.dimension import Thickness | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please run the linter (
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| __all__ = ['Operator'] | ||
|
|
||
|
|
||
|
|
@@ -327,6 +326,8 @@ def _lower_exprs(cls, expressions, **kwargs): | |
| * Apply substitution rules; | ||
| * Shift indices for domain alignment. | ||
| """ | ||
| expressions=lower_multistage(expressions) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whitespace missing around operator - check throughout the PR for this
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
|
||
| expand = kwargs['options'].get('expand', True) | ||
|
|
||
| # Specialization is performed on unevaluated expressions | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,3 +22,5 @@ | |
| from .relational import * # noqa | ||
| from .sparse import * # noqa | ||
| from .tensor import * # noqa | ||
|
|
||
| from .multistage import * | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will need a
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this file should be moved to somewhere like
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I renamed the class to Regarding the file location, it’s currently in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree this file doesn't belong to based on https://github.com/devitocodes/devito/pull/2599/changes#r3043562368, we might add it to |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,186 @@ | ||||||
| # from devito import Function, Eq | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leftover
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| from .equation import Eq | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make these imports absolute (
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| from .dense import Function | ||||||
| from devito.symbolics import uxreplace | ||||||
|
|
||||||
| from .array import Array # Trying Array | ||||||
|
|
||||||
|
|
||||||
| class MultiStage(Eq): | ||||||
| """ | ||||||
| Abstract base class for multi-stage time integration methods | ||||||
| (e.g., Runge-Kutta schemes) in Devito. | ||||||
|
|
||||||
| This class wraps a symbolic equation of the form `target = rhs` and | ||||||
| provides a mechanism to associate a time integration scheme via the | ||||||
| `method` argument. Subclasses must implement the `_evaluate` method to | ||||||
| generate stage-wise update expressions. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| rhs : expr-like | ||||||
| The right-hand side of the equation to integrate. | ||||||
| target : Function | ||||||
| The time-updated symbol on the left-hand side, e.g., `u` or `u.forward`. | ||||||
| method : str or None | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be either a class or a callable imo. Alternatively, it should be entirely omitted and set by defining some In general, if you are using a string comparison, there is probably a better (and safer) way to achieve your aim.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for that, I removed 'method' from the class, but didn't update the docstring... fixing it now |
||||||
| A string identifying the time integration method (e.g., 'RK44'), | ||||||
| which must correspond to a class defined in the global scope and | ||||||
| implementing `_evaluate`. If None, no method is applied. | ||||||
|
|
||||||
| Attributes | ||||||
| ---------- | ||||||
| eq : Eq | ||||||
| The symbolic equation `target = rhs`. | ||||||
| method : class | ||||||
| The integration method class resolved from the `method` string. | ||||||
| """ | ||||||
|
|
||||||
| def __new__(cls, rhs, target, method=None): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is going to strip subdomain information etc, and the API is inconsistent with the standard def __new__(cls, lhs, rhs=0, method=None, subdomain=None, coefficients=None, implicit_dims=None, **kwargs):
obj = super().__new__(lhs, rhs=rhs, subdomain=subdomain, coefficients=coefficients, implicit_dims=implicit_dims, **kwargs)
obj._method = method # NOTE: Have `_resolve_method` as a cached_property or similar based on some processing of `_method`
return obj
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| eq = Eq(target, rhs) | ||||||
| obj = Eq.__new__(cls, eq.lhs, eq.rhs) | ||||||
| obj._eq = eq | ||||||
| obj._method = cls._resolve_method(method) | ||||||
| return obj | ||||||
|
|
||||||
| @classmethod | ||||||
| def _resolve_method(cls, method): | ||||||
| try: | ||||||
| if cls is MultiStage: | ||||||
| return globals()[method] | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We generally try to avoid
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| else: | ||||||
| return cls | ||||||
| except KeyError: | ||||||
| raise ValueError(f"The time integrator '{method}' is not implemented.") | ||||||
|
|
||||||
| @property | ||||||
| def eq(self): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can be dropped with the restructured
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| return self._eq | ||||||
|
|
||||||
| @property | ||||||
| def method(self): | ||||||
| return self._method | ||||||
|
|
||||||
| def _evaluate(self, expand=False): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably take
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| raise NotImplementedError( | ||||||
| f"_evaluate() must be implemented in the subclass {self.__class__.__name__}") | ||||||
|
|
||||||
|
|
||||||
| class RK(MultiStage): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| """ | ||||||
| Base class for explicit Runge-Kutta (RK) time integration methods defined | ||||||
| via a Butcher tableau. | ||||||
|
|
||||||
| This class handles the general structure of RK schemes by using | ||||||
| the Butcher coefficients (`a`, `b`, `c`) to expand a single equation into | ||||||
| a series of intermediate stages followed by a final update. Subclasses | ||||||
| must define `a`, `b`, and `c` as class attributes. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| a : list of list of float | ||||||
| The coefficient matrix representing stage dependencies. | ||||||
| b : list of float | ||||||
| The weights for the final combination step. | ||||||
| c : list of float | ||||||
| The time shifts for each intermediate stage (often the row sums of `a`). | ||||||
|
|
||||||
| Attributes | ||||||
| ---------- | ||||||
| a : list[list[float]] | ||||||
| Butcher tableau `a` coefficients (stage coupling). | ||||||
| b : list[float] | ||||||
| Butcher tableau `b` coefficients (weights for combining stages). | ||||||
| c : list[float] | ||||||
| Butcher tableau `c` coefficients (stage time positions). | ||||||
| s : int | ||||||
| Number of stages in the RK method, inferred from `b`. | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self, *args): | ||||||
| self.a = getattr(self, 'a', None) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems strangely formed. I would have
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I worked on this. Can you confirm if this is what you meant? |
||||||
| self.b = getattr(self, 'b', None) | ||||||
| self.c = getattr(self, 'c', None) | ||||||
| self.s = len(self.b) if self.b is not None else 0 # Number of stages | ||||||
|
|
||||||
| self._validate() | ||||||
|
|
||||||
| def _validate(self): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This error handling should happen at the point where these values are first supplied.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I worked on this too. Can you confirm if this is what you meant? |
||||||
| assert self.a is not None and self.b is not None and self.c is not None, \ | ||||||
| f"RK subclass must define class attributes a, b, and c" | ||||||
| assert len(self.a) == self.s, f"'a'={a} must have {self.s} rows" | ||||||
| assert len(self.c) == self.s, f"'c'={c} must have {self.s} elements" | ||||||
|
|
||||||
| def _evaluate(self, eq_num=0): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should have
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| """ | ||||||
| Generate the stage-wise equations for a Runge-Kutta time integration method. | ||||||
|
|
||||||
| This method takes a single equation of the form `Eq(u.forward, rhs)` and | ||||||
| expands it into a sequence of intermediate stage evaluations and a final | ||||||
| update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| eq_num : int, optional | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You shouldn't need counters like this. Use the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to incorporate this suggestion, but I could use a bit of help. I think I should create a variable like |
||||||
| An identifier index used to uniquely name the intermediate stage variables | ||||||
| (`k{eq_num}i`) in case of multiple equations being expanded. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| list of Eq | ||||||
| A list of SymPy Eq objects representing: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: they will be Devito
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` | ||||||
| - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | ||||||
| """ | ||||||
| base_eq=self.eq | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| u = base_eq.lhs.function | ||||||
| rhs = base_eq.rhs | ||||||
| grid = u.grid | ||||||
| t = grid.time_dim | ||||||
| dt = t.spacing | ||||||
|
|
||||||
| # Create temporary Functions to hold each stage | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: these are
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. right! |
||||||
| # k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||||||
| k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are internal to Devito, should not appear in operator arguments, and should not be touched by the user, and so should use
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The thing is, when I try using |
||||||
| for i in range(self.s)] | ||||||
|
|
||||||
| stage_eqs = [] | ||||||
|
|
||||||
| # Build each stage | ||||||
| for i in range(self.s): | ||||||
| u_temp = u + dt * sum(aij * kj for aij, kj in zip(self.a[i][:i], k[:i])) | ||||||
| t_shift = t + self.c[i] * dt | ||||||
|
|
||||||
| # Evaluate RHS at intermediate value | ||||||
| stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift}) | ||||||
| stage_eqs.append(Eq(k[i], stage_rhs)) | ||||||
|
|
||||||
| # Final update: u.forward = u + dt * sum(b_i * k_i) | ||||||
| u_next = u + dt * sum(bi * ki for bi, ki in zip(self.b, k)) | ||||||
| stage_eqs.append(Eq(u.forward, u_next)) | ||||||
|
|
||||||
| return stage_eqs | ||||||
|
|
||||||
|
|
||||||
| class RK44(RK): | ||||||
| """ | ||||||
| Classic 4th-order Runge-Kutta (RK4) time integration method. | ||||||
|
|
||||||
| This class implements the classic explicit Runge-Kutta method of order 4 (RK44). | ||||||
| It uses four intermediate stages and specific Butcher coefficients to achieve | ||||||
| high accuracy while remaining explicit. | ||||||
|
|
||||||
| Attributes | ||||||
| ---------- | ||||||
| a : list[list[float]] | ||||||
| Coefficients of the `a` matrix for intermediate stage coupling. | ||||||
| b : list[float] | ||||||
| Weights for final combination. | ||||||
| c : list[float] | ||||||
| Time positions of intermediate stages. | ||||||
| """ | ||||||
| a = [[0, 0, 0, 0], | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would set these as tuples in the I would personally instead have a def __init__(self):
a = (...
b = (...
c = (...
super.__init__(a=a, b=b, c=c)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did something like that... could you check? |
||||||
| [1/2, 0, 0, 0], | ||||||
| [0, 1/2, 0, 0], | ||||||
| [0, 0, 1, 0]] | ||||||
| b = [1/6, 1/3, 1/3, 1/6] | ||||||
| c = [0, 1/2, 1/2, 1] | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than having the the
as_tuplehere, why not have a dispatch for_lower_multistagethat dispatches on iterable types as per_concretize_subdims?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, I think..