Skip to content

Commit e9d2000

Browse files
committed
merging two classes of Runge-Kutta
1 parent 552fd7f commit e9d2000

1 file changed

Lines changed: 16 additions & 20 deletions

File tree

devito/types/multistage.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _evaluate(self, **kwargs):
176176
f"_evaluate() must be implemented in the subclass {self.__class__.__name__}")
177177

178178

179-
class RungeKutta(MultiStage):
179+
class TableauRungeKutta(MultiStage):
180180
"""
181181
Base class for explicit Runge-Kutta (RK) time integration methods defined
182182
via a Butcher tableau.
@@ -210,8 +210,14 @@ class RungeKutta(MultiStage):
210210
CoeffsBC = tuple[float | np.number, ...]
211211
CoeffsA = tuple[CoeffsBC, ...]
212212

213-
def __init__(self, a: CoeffsA, b: CoeffsBC, c: CoeffsBC, lhs, rhs, **kwargs) -> None:
214-
self.a, self.b, self.c = a, b, c
213+
def __init__(self, lhs, rhs, a: CoeffsA = None, b: CoeffsBC = None,
214+
c: CoeffsBC = None, **kwargs) -> None:
215+
self.a = a if a is not None else getattr(self, 'a', None)
216+
self.b = b if b is not None else getattr(self, 'b', None)
217+
self.c = c if c is not None else getattr(self, 'c', None)
218+
219+
if self.a is None or self.b is None or self.c is None:
220+
raise ValueError("TableauRungeKutta requires coefficients 'a', 'b', and 'c'.")
215221

216222
@property
217223
def s(self):
@@ -240,7 +246,8 @@ def _evaluate(self, **kwargs):
240246
for j in range(self.n_eq):
241247
k_j = []
242248
for _ in range(self.s):
243-
k_j.append(TimeFunction(name=f'{sregistry.make_name(prefix="k")}', grid=self.lhs[j].grid,
249+
k_name = sregistry.make_name(prefix="k")
250+
k_j.append(TimeFunction(name=k_name, grid=self.lhs[j].grid,
244251
space_order=self.lhs[j].space_order, time_order=0, dtype=self.lhs[j].dtype))
245252
k.append(k_j)
246253

@@ -265,21 +272,10 @@ def _evaluate(self, **kwargs):
265272
for l in range(self.n_eq)])
266273

267274
return stage_eqs
268-
269-
270-
class FixedTableauRungeKutta(RungeKutta):
271-
"""
272-
Runge-Kutta variant with class-defined Butcher tableau coefficients.
273-
274-
Subclasses must define class attributes `a`, `b`, and `c`.
275-
"""
276-
277-
def __init__(self, lhs, rhs, **kwargs):
278-
super().__init__(a=self.a, b=self.b, c=self.c, lhs=lhs, rhs=rhs, **kwargs)
279-
275+
280276

281277
@register_method(aliases=['RK44'])
282-
class RungeKutta44(FixedTableauRungeKutta):
278+
class RungeKutta44(TableauRungeKutta):
283279
"""
284280
Classic 4th-order Runge-Kutta (RK4) time integration method.
285281
@@ -302,7 +298,7 @@ class RungeKutta44(FixedTableauRungeKutta):
302298
c = (0, 1/2, 1/2, 1)
303299

304300
@register_method(aliases=['RK32'])
305-
class RungeKutta32(FixedTableauRungeKutta):
301+
class RungeKutta32(TableauRungeKutta):
306302
"""
307303
3 stages 2nd-order Runge-Kutta (RK32) time integration method.
308304
@@ -324,7 +320,7 @@ class RungeKutta32(FixedTableauRungeKutta):
324320
c = (0, 1/2, 1/2)
325321

326322
@register_method(aliases=['RK97'])
327-
class RungeKutta97(FixedTableauRungeKutta):
323+
class RungeKutta97(TableauRungeKutta):
328324
"""
329325
9 stages 7th-order Runge-Kutta (RK97) time integration method.
330326
@@ -525,7 +521,7 @@ def _evaluate(self, **kwargs):
525521
# update stage equations with source contributions
526522
stage_eqs.extend([Eq(k_j, k_old_j+mu*self.dt*src_lhs_j) for k_j, k_old_j, src_lhs_j in zip(k, k_old, src_lhs)])
527523

528-
# include the last stage to the final approximation with the corresponding stage coefficient (alpha[i])
524+
# include the last stage to the final approximation with the corresponding alpha coefficient
529525
stage_eqs.extend([Eq(lhs_j.forward, lhs_j.forward+k_j*alpha[i]) for lhs_j, k_j in zip(self.lhs, k)])
530526

531527
# Final Runge-Kutta updates

0 commit comments

Comments
 (0)