-
Notifications
You must be signed in to change notification settings - Fork 256
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
1d830b8
7f087b3
214d882
d6c4d4a
78f8a0b
1c9d517
11db48b
83dfb04
d47a106
1f93a45
eea3a52
11d1429
4637ac2
ac1da7e
e9b3533
dc3dd77
1fd4a02
a0c45c1
93c6e3f
ef8d1ac
fa5acac
143d0c2
5f67b91
552fd7f
e9d2000
f7c9ea3
cf1003c
1fd480b
a875224
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| from devito import Function, Eq | ||
| from devito.symbolics import uxreplace | ||
| from sympy import Basic | ||
|
|
||
|
|
||
| class MultiStage(Basic): | ||
|
fernanvr marked this conversation as resolved.
Outdated
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense for this to subclass
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would move all of the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you’re right — I’ve updated the class inheritance accordingly. |
||
| def __new__(cls, eq, method): | ||
| assert isinstance(eq, Eq) | ||
| return Basic.__new__(cls, eq, method) | ||
|
|
||
| @property | ||
| def eq(self): | ||
| return self.args[0] | ||
|
|
||
| @property | ||
| def method(self): | ||
| return self.args[1] | ||
|
|
||
|
|
||
| class RK(Basic): | ||
| """ | ||
| A class representing an explicit Runge-Kutta method via its Butcher tableau. | ||
|
fernanvr marked this conversation as resolved.
Outdated
|
||
|
|
||
| Parameters | ||
| ---------- | ||
| a : list[list[float]] | ||
|
fernanvr marked this conversation as resolved.
Outdated
|
||
| Lower-triangular coefficient matrix (stage dependencies). | ||
| b : list[float] | ||
| Weights for the final combination step. | ||
| c : list[float] | ||
| Weights for the stages time step. | ||
| """ | ||
|
|
||
| def __init__(self, a, b, c): | ||
| self.a = a | ||
| self.b = b | ||
| self.c = c | ||
| self.s = len(b) # number of stages | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: two spaces before inline comment, and should start with a capital letter
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks and done |
||
|
|
||
| self._validate() | ||
|
|
||
| def _validate(self): | ||
| assert len(self.a) == self.s, "'a' must have s rows" | ||
|
fernanvr marked this conversation as resolved.
Outdated
|
||
| for i, row in enumerate(self.a): | ||
| assert len(row) == i, f"Row {i} in 'a' must have {i} entries for explicit RK" | ||
|
|
||
| def expand_stages(self, base_eq, eq_num=0): | ||
|
fernanvr marked this conversation as resolved.
Outdated
|
||
| """ | ||
| Expand a single Eq into a list of stage-wise Eqs for this RK method. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| base_eq : Eq | ||
| The equation Eq(u.forward, rhs) to be expanded into RK stages. | ||
| eq_number : integer, optional | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: double-check docstrings, there are some inconsistencies and typos in here
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I’ve resolved it, but let me know if anything still looks off. |
||
| The equation number to idetify the k_i's stages | ||
|
|
||
| Returns | ||
| ------- | ||
| list of Eq | ||
| Stage-wise equations: [k0=..., k1=..., ..., u.forward=...] | ||
| """ | ||
| u = base_eq.lhs.function | ||
| rhs = base_eq.rhs | ||
| grid = u.grid | ||
| dt = grid.stepping_dim.spacing | ||
| t = grid.time_dim | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. t = grid.time_dim
dt = t.spacingwould be a little neater
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. appreciated and done |
||
|
|
||
| # Create temporary Functions to hold each stage | ||
| k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to use Given that these
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I think these need to be
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part was tricky, I still haven’t figured it out. I left a commented line where I tried defining the |
||
| for i in range(self.s)] | ||
|
|
||
| stage_eqs = [] | ||
|
|
||
| # Build each stage | ||
| for i in range(self.s): | ||
| u_temp = u | ||
| for j in range(i): | ||
| if self.a[i][j] != 0: | ||
| u_temp += self.a[i][j] * dt * k[j] | ||
| t_shift = t + self.c[i] * dt | ||
|
|
||
| # Evaluate RHS at intermediate value | ||
| stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift}) | ||
| stage_eqs.append(Eq(k[i], stage_rhs)) | ||
|
|
||
| # Final update: u.forward = u + dt * sum(bᵢ * kᵢ) | ||
| u_next = u | ||
| for i in range(self.s): | ||
| u_next += self.b[i] * dt * k[i] | ||
|
fernanvr marked this conversation as resolved.
Outdated
|
||
| stage_eqs.append(Eq(u.forward, u_next)) | ||
|
|
||
| return stage_eqs | ||
|
|
||
| # ---- Named methods for convenience ---- | ||
| @classmethod | ||
| def RK44(cls): | ||
| """Classical Runge-Kutta of 4 stages and 4th order""" | ||
| a = [ | ||
| [], | ||
| [1 / 2], | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick:
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. appreciated and done |
||
| [0, 1 / 2], | ||
| [0, 0, 1] | ||
| ] | ||
| b = [1 / 6, 1 / 3, 1 / 3, 1 / 6] | ||
| c = [0, 1 / 2, 1 / 2, 1] | ||
| return cls(a, b, c) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,6 +36,7 @@ | |
| disk_layer) | ||
| from devito.types.dimension import Thickness | ||
|
|
||
| from devito.operator.new_classes import MultiStage | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please run the linter (
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| __all__ = ['Operator'] | ||
|
|
||
|
|
@@ -184,8 +185,9 @@ def _sanitize_exprs(cls, expressions, **kwargs): | |
| expressions = as_tuple(expressions) | ||
|
|
||
| for i in expressions: | ||
| if not isinstance(i, Evaluable): | ||
| raise CompilationError(f"`{i!s}` is not an Evaluable object; " | ||
| i_check = i.eq if isinstance(i, MultiStage) else i | ||
|
fernanvr marked this conversation as resolved.
Outdated
|
||
| if not isinstance(i_check, Evaluable): | ||
| raise CompilationError(f"`{i_check!s}` is not an Evaluable object; " | ||
| "check your equation again") | ||
|
|
||
| return expressions | ||
|
|
@@ -271,6 +273,9 @@ def _lower(cls, expressions, **kwargs): | |
| # expression for which a partial or complete lowering is desired | ||
| kwargs['rcompile'] = cls._rcompile_wrapper(**kwargs) | ||
|
|
||
| # [MultiStage] -> [Eqs] | ||
| expressions = cls._lower_multistage(expressions, **kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be called inside
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed it to be called within |
||
|
|
||
| # [Eq] -> [LoweredEq] | ||
| expressions = cls._lower_exprs(expressions, **kwargs) | ||
|
|
||
|
|
@@ -314,6 +319,27 @@ def _specialize_exprs(cls, expressions, **kwargs): | |
| """ | ||
| return expressions | ||
|
|
||
| @classmethod | ||
| @timed_pass(name='lowering.MultiStages') | ||
| def _lower_multistage(cls, expressions, **kwargs): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| """ | ||
| Separating the multi-stage time-integrator scheme in stages: | ||
|
|
||
| * Check if the time-integrator is Multistage; | ||
| * Creating the stages of the method. | ||
| """ | ||
|
|
||
| lowered = [] | ||
| for i, eq in enumerate(as_tuple(expressions)): | ||
| if isinstance(eq, MultiStage): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would using
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe it's done now. |
||
| time_int = eq.method | ||
| stage_eqs = time_int.expand_stages(eq.eq, eq_num=i) | ||
| lowered.extend(stage_eqs) | ||
| else: | ||
| lowered.append(eq) | ||
|
|
||
| return lowered | ||
|
|
||
| @classmethod | ||
| @timed_pass(name='lowering.Expressions') | ||
| def _lower_exprs(cls, expressions, **kwargs): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| import numpy as np | ||
| import sympy as sym | ||
| import matplotlib.pyplot as plt | ||
|
|
||
| import devito as dv | ||
|
fernanvr marked this conversation as resolved.
Outdated
|
||
| from examples.seismic import Receiver, TimeAxis | ||
|
|
||
| from devito.operator.new_classes import RK, MultiStage | ||
| # Set logging level for debugging | ||
| dv.configuration['log-level'] = 'DEBUG' | ||
|
|
||
| # Parameters | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wants to use It's worth noting that tests (especially the very elementary ones) can be physically nonsensical, in order to check that some aspect of the API or compiler functions as intended. For example, it would be useful to test that a range of Butcher tableaus get correctly assembled into the corresponding timestepping schemes, even if those timestepping schemes themselves are meaningless. Examples of simplifications used for unit tests include using 1D where possible and using trivial
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Other tests that we include for essentially every new class include ensuring that it can be pickled (see test_pickle.py) and that it can be rebuilt correctly using its
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I’ve been working in that direction, but I think it’s still a bit rough |
||
| space_order = 2 | ||
| fd_order = 2 | ||
| extent = (1000, 1000) | ||
| shape = (201, 201) | ||
| origin = (0, 0) | ||
|
|
||
| # Grid setup | ||
| grid = dv.Grid(origin=origin, extent=extent, shape=shape, dtype=np.float64) | ||
| x, y = grid.dimensions | ||
| dt = grid.stepping_dim.spacing | ||
| t = grid.time_dim | ||
| dx = extent[0] / (shape[0] - 1) | ||
|
|
||
| # Medium velocity model | ||
| vel = dv.Function(name="vel", grid=grid, space_order=space_order, dtype=np.float64) | ||
| vel.data[:] = 1.0 | ||
| vel.data[150:, :] = 1.3 | ||
|
|
||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| fun_labels = ['u', 'v'] | ||
| U = [dv.TimeFunction(name=name, grid=grid, space_order=space_order, | ||
| time_order=1, dtype=np.float64) for name in fun_labels] | ||
|
|
||
| # Time axis | ||
| t0, tn = 0.0, 500.0 | ||
| dt0 = np.max(vel.data) / dx**2 | ||
| nt = int((tn - t0) / dt0) | ||
| dt0 = tn / nt | ||
| time_range = TimeAxis(start=t0, stop=tn, num=nt + 1) | ||
|
|
||
| # Receiver setup | ||
| rec = Receiver(name='rec', grid=grid, npoint=3, time_range=time_range) | ||
| rec.coordinates.data[:, 0] = np.linspace(0, 1, 3) | ||
| rec.coordinates.data[:, 1] = 0.5 | ||
| rec = rec.interpolate(expr=U[0].forward) | ||
|
|
||
| # Source definition | ||
| src_spatial = dv.Function(name="src_spat", grid=grid, space_order=space_order, dtype=np.float64) | ||
| src_spatial.data[100, 100] = 1 / dx**2 | ||
|
|
||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (np.pi * f0 * (t * dt - 1/f0))**2) * sym.exp(-(np.pi * f0 * (t * dt - 1/f0))**2) | ||
|
|
||
| # PDE system (2D acoustic) | ||
| system_eqs = [U[1], | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe a simple heat equation would be a good example/test?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I incorporated for the tests |
||
| (dv.Derivative(U[0], (x, 2), fd_order=fd_order) + | ||
| dv.Derivative(U[0], (y, 2), fd_order=fd_order) + | ||
| src_spatial * src_temporal) * vel**2] | ||
|
|
||
| # Time integration scheme | ||
| rk = RK.RK44() | ||
|
|
||
| # MultiStage object | ||
| pdes = [MultiStage(dv.Eq(U[i], system_eqs[i]), rk) for i in range(2)] | ||
|
|
||
| # Construct and run operator | ||
| op = dv.Operator(pdes + [rec], subs=grid.spacing_map) | ||
| op(dt=dt0, time=nt) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the generated C code, I noticed the following: Is this intended? It doesn't seem optimal since they are all assigned the same value |
||
| # Plot final wavefield | ||
| plt.imshow(U[0].data[1, :], cmap="seismic") | ||
| plt.colorbar(label="Amplitude") | ||
| plt.title("Wavefield snapshot (t = final)") | ||
| plt.xlabel("x") | ||
| plt.ylabel("y") | ||
| plt.tight_layout() | ||
| plt.show() | ||
Uh oh!
There was an error while loading. Please reload this page.