Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1d830b8
implementation of multi-stage time integrators
fernanvr May 5, 2025
7f087b3
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Jun 13, 2025
214d882
Return of first PR comments
fernanvr Jun 13, 2025
d6c4d4a
Return of first PR comments
fernanvr Jun 13, 2025
78f8a0b
2nd PR revision
fernanvr Jun 23, 2025
1c9d517
2nd PR revision
fernanvr Jun 23, 2025
11db48b
2rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
83dfb04
3rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
d47a106
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
1f93a45
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
eea3a52
5th PR revision, one suggestion from EdC and improving tests
fernanvr Jun 26, 2025
11d1429
including two more Runge-Kutta methods and improving tests: checking …
fernanvr Jul 1, 2025
4637ac2
changes to consider coupled Multistage equations
fernanvr Jul 16, 2025
ac1da7e
Improvements of the HORK_EXP
fernanvr Aug 15, 2025
e9b3533
Merge branch 'main' into multi-stage-time-integrator
fernanvr Aug 15, 2025
dc3dd77
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Oct 8, 2025
1fd4a02
tuples, improved class names, extensive tests
fernanvr Oct 8, 2025
a0c45c1
improving spacing in some tests
fernanvr Oct 8, 2025
93c6e3f
Add MFE time stepping Jupyter notebook
fernanvr Oct 23, 2025
ef8d1ac
Remove MFE_time_size.ipynb notebook
fernanvr Oct 23, 2025
fa5acac
Update multistage implementation and tests
fernanvr Oct 29, 2025
143d0c2
Return of first PR comments
fernanvr Jun 13, 2025
5f67b91
updating small changes from EdC review on 26-03-2026
fernanvr Mar 26, 2026
552fd7f
Isolate multistage-related changes only
fernanvr Mar 27, 2026
e9d2000
merging two classes of Runge-Kutta
fernanvr Mar 27, 2026
f7c9ea3
Merge full multistage history while keeping clean tree
fernanvr Mar 27, 2026
cf1003c
fixing test_multistage file
fernanvr Apr 6, 2026
1fd480b
Remove devito/ir/equations/algorithms.py and devito/operator/operator…
fernanvr Apr 9, 2026
a875224
implemented suggestions of EdC and Fabio
fernanvr Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def _(exprs, **kwargs):
Handle iterables of expressions.
"""
lowered = []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you do return [_lower_multistage(expr, **kwargs) for i in exprs for expr in i]?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did something like that...

for i, expr in enumerate(exprs):
lowered.extend(_lower_multistage(expr, eq_num=i))
for expr in exprs:
lowered.extend(_lower_multistage(expr, **kwargs))
return lowered


Expand Down
6 changes: 1 addition & 5 deletions devito/operations/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ def solve(eq, target, **kwargs):
sols_temp = sols[0]

method = kwargs.get("method", None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the method_registry mapper. Furthermore, it would allow you to have method.resolve(target, sols_temp) here, which is tidier

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a string. The idea is that the user provides a string to identify which time integrator to apply.

if method is not None:
method_cls = resolve_method(method)
return method_cls(target, sols_temp)._evaluate(**kwargs)
else:
return sols_temp
return sols_temp if method is None else resolve_method(method)(target, sols_temp)


def linsolve(expr, target, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _lower_exprs(cls, expressions, **kwargs):
* Apply substitution rules;
* Shift indices for domain alignment.
"""
expressions = lower_multistage(expressions)
expressions = lower_multistage(expressions, **kwargs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should at least be called after expand = ...

and, perhaps, benefit from a more generic name such as lower_timestepping

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could also move it inside a _lower_dsl, which internally calls _specialize_dsl, just like we already do for expressions/clusters/stree/iet

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, thanks! This was part of an earlier approach that after one meeting with Devito's team was decided to be left like a plan b, so it shouldn’t actually be here. I’ll remove it from the PR to appear only the actual approach—though I agree that structuring it that way would make sense if we revisit this idea in the future and I already changed accordingly.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick correction to my previous comment: I realized this part is actually still in use in the current implementation. I’ve updated it taking your suggestions into account (ordering + naming), so it should now reflect what it was intended.


expand = kwargs['options'].get('expand', True)

Expand Down
7 changes: 4 additions & 3 deletions devito/types/multistage.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file should be moved to somewhere like devito/timestepping/rungekutta.py or devito/timestepping/explicitmultistage.py that way additional timesteppers can be contributed as new files. (I'm thinking about implicit multistage, backward difference formulae etc...)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the class to HighOrderRungeKuttaExponential. I realize the name might be confusing since this particular Runge-Kutta is explicit, but “EXP” was intended to highlight the exponential aspect. I’ve also updated the other class names based on your suggestions.

Regarding the file location, it’s currently in /types as recommended by @mloubout (see suggestion). Personally, I think both /timestepping and /types are reasonable options. Perhaps we can discuss this with @EdCaunt and @FabioLuporini to reach a consensus.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this file doesn't belong to types/

based on https://github.com/devitocodes/devito/pull/2599/changes#r3043562368, we might add it to ir/dsl/rungekutta.py

Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,17 @@ def _evaluate(self, **kwargs):
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""
eq_num = kwargs['eq_num']
stage_id = kwargs.get('sregistry').make_name(prefix='k')

u = self.lhs.function
rhs = self.rhs
grid = u.grid
t = grid.time_dim
dt = t.spacing

# Create temporary Functions to hold each stage
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: these are Array now

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right!

# k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
# k = [Array(name=f'{stage_id}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'{stage_id}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kwargs.get('sregistry').make_name(prefix='k') wants to be inside this loop to ensure that all names are unique

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

for i in range(self.s)]

stage_eqs = []
Expand Down
126 changes: 80 additions & 46 deletions tests/test_multistage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from devito import (Grid, Function, TimeFunction,
Derivative, Operator, solve, Eq)
from devito.types.multistage import resolve_method
from devito.ir.support import SymbolRegistry
from devito.ir.equations import lower_multistage


def test_multistage_solve(time_int='RK44'):
def test_multistage_object(time_int='RK44'):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using pytest.mark.parametrize here. To add, some test classes like TestLowering, TestAPI, TestRK, etc would help with organisation of this file and running specific batches of tests

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

extent = (1, 1)
shape = (3, 3)
origin = (0, 0)
Expand All @@ -25,20 +27,19 @@ def test_multistage_solve(time_int='RK44'):
# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what its worth, tests of this kind don't need to have any physical significance, so long as they produce the desired behaviour in the compiler that you are testing for. For example, you could probably omit the source terms entirely and probably the derivatives too, simply creating a multistage timestepper out of a trivial equation that adds one to the solution at each timestep or similar.

However this is still a well-made and focussed test as-is

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I think it’s important to have a test that includes derivatives and source terms. However, I agree that simpler examples should also be included. I’ve added one without those elements.


# PDE system (2D acoustic)
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int, eq_num=i) for i in range(2)]
# Class of the time integration scheme
return [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)]


def test_multistage_op_constructing_directly(time_int='RK44'):
def test_multistage_lower_multistage(time_int='RK44'):
extent = (1, 1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of boilerplate is repeated in these tests. Consider a convenience function for setting up the grid etc

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

were created two functions to reduce code repetition

shape = (3, 3)
origin = (0, 0)
Expand All @@ -57,23 +58,55 @@ def test_multistage_op_constructing_directly(time_int='RK44'):
# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE system (2D acoustic)
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme

# Class of the time integration scheme
pdes = [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)]
op = Operator(pdes, subs=grid.spacing_map)
op(dt=0.01, time=1)

sregistry=SymbolRegistry()

def test_multistage_op_computing_directly(time_int='RK44'):
return lower_multistage(pdes, sregistry=sregistry)



def test_multistage_solve(time_int='RK44'):
extent = (1, 1)
shape = (3, 3)
origin = (0, 0)

# Grid setup
grid = Grid(origin=origin, extent=extent, shape=shape, dtype=float64)
x, y = grid.dimensions
dt = grid.stepping_dim.spacing
t = grid.time_dim

# Define wavefield unknowns: u (displacement) and v (velocity)
fun_labels = ['u', 'v']
U = [TimeFunction(name=name, grid=grid, space_order=2,
time_order=1, dtype=float64) for name in fun_labels]

# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE system (2D acoustic)
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int) for i in range(2)]


def test_multistage_op_computing_1eq(time_int='RK44'):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since all of the tests in this file pertain to MultiStage, you can drop multistage from all function names within the file for concision

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you'r right, dropped

extent = (1, 1)
shape = (200, 200)
origin = (0, 0)
Expand All @@ -85,40 +118,44 @@ def test_multistage_op_computing_directly(time_int='RK44'):
t = grid.time_dim

# Define wavefield unknowns: u (displacement) and v (velocity)
u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64)
fun_labels = ['u_multi_stage', 'v_multi_stage']
U_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use of a capital U here makes this look like a class, consider renaming

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed for U_multi_stage and for U.

time_order=1, dtype=float64) for name in fun_labels]

# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE (2D heat eq.)
eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) +
src_spatial * src_temporal)
# PDE system
system_eqs_rhs = [U_multi_stage[1] + src_spatial * src_temporal,
Derivative(U_multi_stage[0], (x, 2), fd_order=2) +
Derivative(U_multi_stage[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
pde = [MultiStage(eq_rhs, u_multi_stage, method=time_int)]
op = Operator(pde, subs=grid.spacing_map)
pdes = [resolve_method(time_int)(U_multi_stage[i], system_eqs_rhs[i]) for i in range(2)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not a fan of this resolve_method("method_name") API. I think MethodClass(lhs, rhs) is far less ambiguous

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did a change on this, not quite as your suggestion because I think it is not friendly asking to the user to import the specific class of the time integration.

op = Operator(pdes, subs=grid.spacing_map)
op(dt=0.01, time=1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should assert a norm or similar

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your point. But the idea of this test if only to check if op.apply() executes for multistage objects. Is not about the convergence. Do you think it is unnecessary do that?


# Solving now using Devito's standard time solver
u = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64)
eq_rhs = (Derivative(u, (x, 2), fd_order=2) + Derivative(u, (y, 2), fd_order=2) +
src_spatial * src_temporal)
# Define wavefield unknowns: u (displacement) and v (velocity)
fun_labels = ['u', 'v']
U = [TimeFunction(name=name, grid=grid, space_order=2,
time_order=1, dtype=float64) for name in fun_labels]
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
pde = Eq(u, solve(eq_rhs - u, u))
op = Operator(pde, subs=grid.spacing_map)
pdes = [Eq(U[i], system_eqs_rhs[i]) for i in range(2)]
op = Operator(pdes, subs=grid.spacing_map)
op(dt=0.01, time=1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also assert something

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


return max(abs(u_multi_stage.data[0, :] - u.data[0, :]))
return max(abs(U_multi_stage[0].data[0, :] - U[0].data[0, :]))

# test_multistage_op_constructing_directly()

# test_multistage_op_computing_directly()

def test_multistage_op_solve_computing(time_int='RK44'):
def test_multistage_op_computing_directly(time_int='RK44'):
extent = (1, 1)
shape = (200, 200)
origin = (0, 0)
Expand All @@ -129,22 +166,21 @@ def test_multistage_op_solve_computing(time_int='RK44'):
dt = grid.stepping_dim.spacing
t = grid.time_dim

# Define unknown for the 'time_int' method: u (heat)
u_time_int = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64)
# Define wavefield unknowns: u (displacement) and v (velocity)
u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64)

# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE (2D heat eq.)
eq_rhs = (Derivative(u_time_int, (x, 2), fd_order=2) + Derivative(u_time_int, (y, 2), fd_order=2) +
src_spatial * src_temporal)
eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) +
src_spatial * src_temporal)

# Time integration scheme
pde = solve(eq_rhs - u_time_int, u_time_int, method=time_int)
op=Operator(pde, subs=grid.spacing_map)
pde = [resolve_method(time_int)(eq_rhs, u_multi_stage)]
op = Operator(pde, subs=grid.spacing_map)
op(dt=0.01, time=1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again should assert a norm. Can also be consolidated with the previous test via parameterisation

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same, the idea is to check if op.apply() executes without an error...


# Solving now using Devito's standard time solver
Expand All @@ -157,6 +193,4 @@ def test_multistage_op_solve_computing(time_int='RK44'):
op = Operator(pde, subs=grid.spacing_map)
op(dt=0.01, time=1)

Comment thread
EdCaunt marked this conversation as resolved.
return max(abs(u_time_int.data[0,:]-u.data[0,:]))

# test_multistage_op_solve_computing()
return max(abs(u_multi_stage.data[0, :] - u.data[0, :]))