1- from .equation import Eq
2- from .dense import Function
1+ from devito . types .equation import Eq
2+ from devito . types .dense import Function
33from devito .symbolics import uxreplace
44from numpy import number
5-
6- from .array import Array # Trying Array
7-
5+ from devito .types .array import Array
6+ from types import MappingProxyType
87
98method_registry = {}
109
@@ -22,38 +21,63 @@ def resolve_method(method):
2221
2322class MultiStage (Eq ):
2423 """
25- Abstract base class for multi-stage time integration methods
26- (e.g., Runge-Kutta schemes) in Devito.
27-
28- This class represents a symbolic equation of the form `target = rhs`
29- and provides a mechanism to associate it with a time integration
30- scheme. The specific integration behavior must be implemented by
31- subclasses via the `_evaluate` method.
32-
33- Parameters
34- ----------
35- lhs : expr-like
36- The left-hand side of the equation, typically a time-updated Function
37- (e.g., `u.forward`).
38- rhs : expr-like, optional
39- The right-hand side of the equation to integrate. Defaults to 0.
40- subdomain : SubDomain, optional
41- A subdomain over which the equation applies.
42- coefficients : dict, optional
43- Optional dictionary of symbolic coefficients for the integration.
44- implicit_dims : tuple, optional
45- Additional dimensions that should be treated implicitly in the equation.
46- **kwargs : dict
47- Additional keyword arguments, such as time integration method selection.
48-
49- Notes
50- -----
51- Subclasses must override the `_evaluate()` method to return a sequence
52- of update expressions for each stage in the integration process.
53- """
54-
55- def __new__ (cls , lhs , rhs = 0 , subdomain = None , coefficients = None , implicit_dims = None , ** kwargs ):
56- return super ().__new__ (cls , lhs , rhs = rhs , subdomain = subdomain , coefficients = coefficients , implicit_dims = implicit_dims , ** kwargs )
24+ Abstract base class for multi-stage time integration methods
25+ (e.g., Runge-Kutta schemes) in Devito.
26+
27+ This class represents a symbolic equation of the form `target = rhs`
28+ and provides a mechanism to associate it with a time integration
29+ scheme. The specific integration behavior must be implemented by
30+ subclasses via the `_evaluate` method.
31+
32+ Parameters
33+ ----------
34+ lhs : expr-like
35+ The left-hand side of the equation, typically a time-updated Function
36+ (e.g., `u.forward`).
37+ rhs : expr-like, optional
38+ The right-hand side of the equation to integrate. Defaults to 0.
39+ subdomain : SubDomain, optional
40+ A subdomain over which the equation applies.
41+ coefficients : dict, optional
42+ Optional dictionary of symbolic coefficients for the integration.
43+ implicit_dims : tuple, optional
44+ Additional dimensions that should be treated implicitly in the equation.
45+ **kwargs : dict
46+ Additional keyword arguments, such as time integration method selection.
47+
48+ Notes
49+ -----
50+ Subclasses must override the `_evaluate()` method to return a sequence
51+ of update expressions for each stage in the integration process.
52+ """
53+
54+ def __new__ (cls , lhs , rhs , ** kwargs ):
55+ if not isinstance (lhs , list ):
56+ lhs = [lhs ]
57+ rhs = [rhs ]
58+ obj = super ().__new__ (cls , lhs [0 ], rhs [0 ], ** kwargs )
59+
60+ # Store all equations
61+ obj ._eq = [Eq (lhs [i ], rhs [i ]) for i in range (len (lhs ))]
62+ obj ._lhs = lhs
63+ obj ._rhs = rhs
64+
65+ return obj
66+
67+ @property
68+ def eq (self ):
69+ """Return the full list of equations."""
70+ return self ._eq
71+
72+ @property
73+ def lhs (self ):
74+ """Return list of left-hand sides."""
75+ return self ._lhs
76+
77+ @property
78+ def rhs (self ):
79+ """Return list of right-hand sides."""
80+ return self ._rhs
5781
5882 def _evaluate (self , ** kwargs ):
5983 raise NotImplementedError (
@@ -91,7 +115,7 @@ class RK(MultiStage):
91115 Number of stages in the RK method, inferred from `b`.
92116 """
93117
94- def __init__ (self , a : list [list [float | number ]], b : list [float | number ], c : list [float | number ], ** kwargs ) -> None :
118+ def __init__ (self , a : list [list [float | number ]], b : list [float | number ], c : list [float | number ], lhs , rhs , ** kwargs ) -> None :
95119 self .a , self .b , self .c = a , b , c
96120
97121 @property
@@ -113,32 +137,30 @@ def _evaluate(self, **kwargs):
113137 - `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
114138 - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
115139 """
116-
117- u = self .lhs .function
118- rhs = self .rhs
119- grid = u .grid
120- t = grid .time_dim
140+ n_eq = len (self .eq )
141+ u = [i .function for i in self .lhs ]
142+ grid = [u [i ].grid for i in range (n_eq )]
143+ t = grid [0 ].time_dim
121144 dt = t .spacing
122145
123146 # Create temporary Functions to hold each stage
124- # k = [Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
125- k = [Function (name = f'{ kwargs .get ('sregistry' ).make_name (prefix = 'k' )} ' , grid = grid , space_order = u .space_order , dtype = u .dtype )
126- for i in range (self .s )]
147+ k = [[Array (name = f'{ kwargs .get ('sregistry' ).make_name (prefix = 'k' )} ' , dimensions = grid [j ].dimensions , grid = grid [j ], dtype = u [j ].dtype ) for i in range (self .s )]
148+ for j in range (n_eq )]
127149
128150 stage_eqs = []
129151
130152 # Build each stage
131153 for i in range (self .s ):
132- u_temp = u + dt * sum (aij * kj for aij , kj in zip (self .a [i ][:i ], k [:i ]))
154+ u_temp = [ u [ l ] + dt * sum (aij * kj for aij , kj in zip (self .a [i ][:i ], k [l ][ :i ])) for l in range ( n_eq )]
133155 t_shift = t + self .c [i ] * dt
134156
135157 # Evaluate RHS at intermediate value
136- stage_rhs = uxreplace (rhs , {u : u_temp , t : t_shift })
137- stage_eqs .append (Eq (k [i ], stage_rhs ))
158+ stage_rhs = [ uxreplace (self . rhs [ l ] , {** { u [ m ] : u_temp [ m ] for m in range ( n_eq )} , t : t_shift }) for l in range ( n_eq )]
159+ [ stage_eqs .append (Eq (k [l ][ i ], stage_rhs [ l ])) for l in range ( n_eq )]
138160
139161 # Final update: u.forward = u + dt * sum(b_i * k_i)
140- u_next = u + dt * sum (bi * ki for bi , ki in zip (self .b , k ))
141- stage_eqs .append (Eq (u .forward , u_next ))
162+ u_next = [ u [ l ] + dt * sum (bi * ki for bi , ki in zip (self .b , k [ l ])) for l in range ( n_eq )]
163+ [ stage_eqs .append (Eq (u [ l ] .forward , u_next [ l ])) for l in range ( n_eq )]
142164
143165 return stage_eqs
144166
@@ -166,8 +188,8 @@ class RK44(RK):
166188 b = [1 / 6 , 1 / 3 , 1 / 3 , 1 / 6 ]
167189 c = [0 , 1 / 2 , 1 / 2 , 1 ]
168190
169- def __init__ (self , * args , ** kwargs ):
170- super ().__init__ (a = self .a , b = self .b , c = self .c , ** kwargs )
191+ def __init__ (self , lhs , rhs , ** kwargs ):
192+ super ().__init__ (a = self .a , b = self .b , c = self .c , lhs = lhs , rhs = rhs , ** kwargs )
171193
172194
173195@register_method
@@ -354,4 +376,7 @@ def _evaluate(self, **kwargs):
354376 u_next = u + dt * sum (bi * ki for bi , ki in zip (self .b , k ))
355377 stage_eqs .append (Eq (u .forward , u_next ))
356378
357- return stage_eqs
379+ return stage_eqs
380+
381+
382+ method_registry = MappingProxyType (method_registry )
0 commit comments