-
Notifications
You must be signed in to change notification settings - Fork 255
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
1d830b8
7f087b3
214d882
d6c4d4a
78f8a0b
1c9d517
11db48b
83dfb04
d47a106
1f93a45
eea3a52
11d1429
4637ac2
ac1da7e
e9b3533
dc3dd77
1fd4a02
a0c45c1
93c6e3f
ef8d1ac
fa5acac
143d0c2
5f67b91
552fd7f
e9d2000
f7c9ea3
cf1003c
1fd480b
a875224
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,11 +63,7 @@ def solve(eq, target, **kwargs): | |
| sols_temp = sols[0] | ||
|
|
||
| method = kwargs.get("method", None) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a string. The idea is that the user provides a string to identify which time integrator to apply. |
||
| if method is not None: | ||
| method_cls = resolve_method(method) | ||
| return method_cls(target, sols_temp)._evaluate(**kwargs) | ||
| else: | ||
| return sols_temp | ||
| return sols_temp if method is None else resolve_method(method)(target, sols_temp) | ||
|
|
||
|
|
||
| def linsolve(expr, target, **kwargs): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -326,7 +326,7 @@ def _lower_exprs(cls, expressions, **kwargs): | |
| * Apply substitution rules; | ||
| * Shift indices for domain alignment. | ||
| """ | ||
| expressions = lower_multistage(expressions) | ||
| expressions = lower_multistage(expressions, **kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should at least be called after and, perhaps, benefit from a more generic name such as
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could also move it inside a
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense, thanks! This was part of an earlier approach that after one meeting with Devito's team was decided to be left like a plan b, so it shouldn’t actually be here. I’ll remove it from the PR to appear only the actual approach—though I agree that structuring it that way would make sense if we revisit this idea in the future and I already changed accordingly.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quick correction to my previous comment: I realized this part is actually still in use in the current implementation. I’ve updated it taking your suggestions into account (ordering + naming), so it should now reflect what it was intended. |
||
|
|
||
| expand = kwargs['options'].get('expand', True) | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this file should be moved to somewhere like
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I renamed the class to Regarding the file location, it’s currently in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree this file doesn't belong to based on https://github.com/devitocodes/devito/pull/2599/changes#r3043562368, we might add it to |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,16 +127,17 @@ def _evaluate(self, **kwargs): | |
| - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` | ||
| - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | ||
| """ | ||
| eq_num = kwargs['eq_num'] | ||
| stage_id = kwargs.get('sregistry').make_name(prefix='k') | ||
|
|
||
| u = self.lhs.function | ||
| rhs = self.rhs | ||
| grid = u.grid | ||
| t = grid.time_dim | ||
| dt = t.spacing | ||
|
|
||
| # Create temporary Functions to hold each stage | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: these are
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. right! |
||
| # k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||
| k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||
| # k = [Array(name=f'{stage_id}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||
| k = [Function(name=f'{stage_id}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| for i in range(self.s)] | ||
|
|
||
| stage_eqs = [] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,9 +4,11 @@ | |
| from devito import (Grid, Function, TimeFunction, | ||
| Derivative, Operator, solve, Eq) | ||
| from devito.types.multistage import resolve_method | ||
| from devito.ir.support import SymbolRegistry | ||
| from devito.ir.equations import lower_multistage | ||
|
|
||
|
|
||
| def test_multistage_solve(time_int='RK44'): | ||
| def test_multistage_object(time_int='RK44'): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| extent = (1, 1) | ||
| shape = (3, 3) | ||
| origin = (0, 0) | ||
|
|
@@ -25,20 +27,19 @@ def test_multistage_solve(time_int='RK44'): | |
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For what its worth, tests of this kind don't need to have any physical significance, so long as they produce the desired behaviour in the compiler that you are testing for. For example, you could probably omit the source terms entirely and probably the derivatives too, simply creating a multistage timestepper out of a trivial equation that adds one to the solution at each timestep or similar. However this is still a well-made and focussed test as-is
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I think it’s important to have a test that includes derivatives and source terms. However, I agree that simpler examples should also be included. I’ve added one without those elements. |
||
|
|
||
| # PDE system (2D acoustic) | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int, eq_num=i) for i in range(2)] | ||
| # Class of the time integration scheme | ||
| return [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)] | ||
|
|
||
|
|
||
| def test_multistage_op_constructing_directly(time_int='RK44'): | ||
| def test_multistage_lower_multistage(time_int='RK44'): | ||
| extent = (1, 1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lots of boilerplate is repeated in these tests. Consider a convenience function for setting up the grid etc
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. were created two functions to reduce code repetition |
||
| shape = (3, 3) | ||
| origin = (0, 0) | ||
|
|
@@ -57,23 +58,55 @@ def test_multistage_op_constructing_directly(time_int='RK44'): | |
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE system (2D acoustic) | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
|
|
||
| # Class of the time integration scheme | ||
| pdes = [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)] | ||
| op = Operator(pdes, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
|
||
| sregistry=SymbolRegistry() | ||
|
|
||
| def test_multistage_op_computing_directly(time_int='RK44'): | ||
| return lower_multistage(pdes, sregistry=sregistry) | ||
|
|
||
|
|
||
|
|
||
| def test_multistage_solve(time_int='RK44'): | ||
| extent = (1, 1) | ||
| shape = (3, 3) | ||
| origin = (0, 0) | ||
|
|
||
| # Grid setup | ||
| grid = Grid(origin=origin, extent=extent, shape=shape, dtype=float64) | ||
| x, y = grid.dimensions | ||
| dt = grid.stepping_dim.spacing | ||
| t = grid.time_dim | ||
|
|
||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| fun_labels = ['u', 'v'] | ||
| U = [TimeFunction(name=name, grid=grid, space_order=2, | ||
| time_order=1, dtype=float64) for name in fun_labels] | ||
|
|
||
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE system (2D acoustic) | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int) for i in range(2)] | ||
|
|
||
|
|
||
| def test_multistage_op_computing_1eq(time_int='RK44'): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since all of the tests in this file pertain to
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you'r right, dropped |
||
| extent = (1, 1) | ||
| shape = (200, 200) | ||
| origin = (0, 0) | ||
|
|
@@ -85,40 +118,44 @@ def test_multistage_op_computing_directly(time_int='RK44'): | |
| t = grid.time_dim | ||
|
|
||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
| fun_labels = ['u_multi_stage', 'v_multi_stage'] | ||
| U_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use of a capital
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed for |
||
| time_order=1, dtype=float64) for name in fun_labels] | ||
|
|
||
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE (2D heat eq.) | ||
| eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
| # PDE system | ||
| system_eqs_rhs = [U_multi_stage[1] + src_spatial * src_temporal, | ||
| Derivative(U_multi_stage[0], (x, 2), fd_order=2) + | ||
| Derivative(U_multi_stage[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| pde = [MultiStage(eq_rhs, u_multi_stage, method=time_int)] | ||
| op = Operator(pde, subs=grid.spacing_map) | ||
| pdes = [resolve_method(time_int)(U_multi_stage[i], system_eqs_rhs[i]) for i in range(2)] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm still not a fan of this
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did a change on this, not quite as your suggestion because I think it is not friendly asking to the user to import the specific class of the time integration. |
||
| op = Operator(pdes, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should assert a norm or similar
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand your point. But the idea of this test if only to check if op.apply() executes for multistage objects. Is not about the convergence. Do you think it is unnecessary do that? |
||
|
|
||
| # Solving now using Devito's standard time solver | ||
| u = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
| eq_rhs = (Derivative(u, (x, 2), fd_order=2) + Derivative(u, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| fun_labels = ['u', 'v'] | ||
| U = [TimeFunction(name=name, grid=grid, space_order=2, | ||
| time_order=1, dtype=float64) for name in fun_labels] | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| pde = Eq(u, solve(eq_rhs - u, u)) | ||
| op = Operator(pde, subs=grid.spacing_map) | ||
| pdes = [Eq(U[i], system_eqs_rhs[i]) for i in range(2)] | ||
| op = Operator(pdes, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should also assert something
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
|
||
| return max(abs(u_multi_stage.data[0, :] - u.data[0, :])) | ||
| return max(abs(U_multi_stage[0].data[0, :] - U[0].data[0, :])) | ||
|
|
||
| # test_multistage_op_constructing_directly() | ||
|
|
||
| # test_multistage_op_computing_directly() | ||
|
|
||
| def test_multistage_op_solve_computing(time_int='RK44'): | ||
| def test_multistage_op_computing_directly(time_int='RK44'): | ||
| extent = (1, 1) | ||
| shape = (200, 200) | ||
| origin = (0, 0) | ||
|
|
@@ -129,22 +166,21 @@ def test_multistage_op_solve_computing(time_int='RK44'): | |
| dt = grid.stepping_dim.spacing | ||
| t = grid.time_dim | ||
|
|
||
| # Define unknown for the 'time_int' method: u (heat) | ||
| u_time_int = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
|
|
||
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE (2D heat eq.) | ||
| eq_rhs = (Derivative(u_time_int, (x, 2), fd_order=2) + Derivative(u_time_int, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
| eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
|
|
||
| # Time integration scheme | ||
| pde = solve(eq_rhs - u_time_int, u_time_int, method=time_int) | ||
| op=Operator(pde, subs=grid.spacing_map) | ||
| pde = [resolve_method(time_int)(eq_rhs, u_multi_stage)] | ||
| op = Operator(pde, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again should assert a norm. Can also be consolidated with the previous test via parameterisation
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same, the idea is to check if op.apply() executes without an error... |
||
|
|
||
| # Solving now using Devito's standard time solver | ||
|
|
@@ -157,6 +193,4 @@ def test_multistage_op_solve_computing(time_int='RK44'): | |
| op = Operator(pde, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
|
||
|
EdCaunt marked this conversation as resolved.
|
||
| return max(abs(u_time_int.data[0,:]-u.data[0,:])) | ||
|
|
||
| # test_multistage_op_solve_computing() | ||
| return max(abs(u_multi_stage.data[0, :] - u.data[0, :])) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you do
return [_lower_multistage(expr, **kwargs) for i in exprs for expr in i]?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did something like that...