@@ -176,7 +176,7 @@ def _evaluate(self, **kwargs):
176176 f"_evaluate() must be implemented in the subclass { self .__class__ .__name__ } " )
177177
178178
179- class RungeKutta (MultiStage ):
179+ class TableauRungeKutta (MultiStage ):
180180 """
181181 Base class for explicit Runge-Kutta (RK) time integration methods defined
182182 via a Butcher tableau.
@@ -210,8 +210,14 @@ class RungeKutta(MultiStage):
210210 CoeffsBC = tuple [float | np .number , ...]
211211 CoeffsA = tuple [CoeffsBC , ...]
212212
213- def __init__ (self , a : CoeffsA , b : CoeffsBC , c : CoeffsBC , lhs , rhs , ** kwargs ) -> None :
214- self .a , self .b , self .c = a , b , c
213+ def __init__ (self , lhs , rhs , a : CoeffsA = None , b : CoeffsBC = None ,
214+ c : CoeffsBC = None , ** kwargs ) -> None :
215+ self .a = a if a is not None else getattr (self , 'a' , None )
216+ self .b = b if b is not None else getattr (self , 'b' , None )
217+ self .c = c if c is not None else getattr (self , 'c' , None )
218+
219+ if self .a is None or self .b is None or self .c is None :
220+ raise ValueError ("TableauRungeKutta requires coefficients 'a', 'b', and 'c'." )
215221
216222 @property
217223 def s (self ):
@@ -240,7 +246,8 @@ def _evaluate(self, **kwargs):
240246 for j in range (self .n_eq ):
241247 k_j = []
242248 for _ in range (self .s ):
243- k_j .append (TimeFunction (name = f'{ sregistry .make_name (prefix = "k" )} ' , grid = self .lhs [j ].grid ,
249+ k_name = sregistry .make_name (prefix = "k" )
250+ k_j .append (TimeFunction (name = k_name , grid = self .lhs [j ].grid ,
244251 space_order = self .lhs [j ].space_order , time_order = 0 , dtype = self .lhs [j ].dtype ))
245252 k .append (k_j )
246253
@@ -265,21 +272,10 @@ def _evaluate(self, **kwargs):
265272 for l in range (self .n_eq )])
266273
267274 return stage_eqs
268-
269-
270- class FixedTableauRungeKutta (RungeKutta ):
271- """
272- Runge-Kutta variant with class-defined Butcher tableau coefficients.
273-
274- Subclasses must define class attributes `a`, `b`, and `c`.
275- """
276-
277- def __init__ (self , lhs , rhs , ** kwargs ):
278- super ().__init__ (a = self .a , b = self .b , c = self .c , lhs = lhs , rhs = rhs , ** kwargs )
279-
275+
280276
281277@register_method (aliases = ['RK44' ])
282- class RungeKutta44 (FixedTableauRungeKutta ):
278+ class RungeKutta44 (TableauRungeKutta ):
283279 """
284280 Classic 4th-order Runge-Kutta (RK4) time integration method.
285281
@@ -302,7 +298,7 @@ class RungeKutta44(FixedTableauRungeKutta):
302298 c = (0 , 1 / 2 , 1 / 2 , 1 )
303299
304300@register_method (aliases = ['RK32' ])
305- class RungeKutta32 (FixedTableauRungeKutta ):
301+ class RungeKutta32 (TableauRungeKutta ):
306302 """
307303 3 stages 2nd-order Runge-Kutta (RK32) time integration method.
308304
@@ -324,7 +320,7 @@ class RungeKutta32(FixedTableauRungeKutta):
324320 c = (0 , 1 / 2 , 1 / 2 )
325321
326322@register_method (aliases = ['RK97' ])
327- class RungeKutta97 (FixedTableauRungeKutta ):
323+ class RungeKutta97 (TableauRungeKutta ):
328324 """
329325 9 stages 7th-order Runge-Kutta (RK97) time integration method.
330326
@@ -525,7 +521,7 @@ def _evaluate(self, **kwargs):
525521 # update stage equations with source contributions
526522 stage_eqs .extend ([Eq (k_j , k_old_j + mu * self .dt * src_lhs_j ) for k_j , k_old_j , src_lhs_j in zip (k , k_old , src_lhs )])
527523
528- # include the last stage to the final approximation with the corresponding stage coefficient (alpha[i])
524+ # include the last stage to the final approximation with the corresponding alpha coefficient
529525 stage_eqs .extend ([Eq (lhs_j .forward , lhs_j .forward + k_j * alpha [i ]) for lhs_j , k_j in zip (self .lhs , k )])
530526
531527 # Final Runge-Kutta updates
0 commit comments