@@ -198,6 +198,9 @@ class WaveletSource(PointSource):
198198 Amplitude of the wavelet (defaults to 1).
199199 t0 : float, optional
200200 Firing time (defaults to 1 / f0)
201+ wavelet: str, optional
202+ The type of wavelet to generate one of
203+ {'gauss_soliton', 'dgauss', 'ricker', 'gabor'}
201204 """
202205
203206 __rkwargs__ = PointSource .__rkwargs__ + ['f0' , 'a' , 't0' ]
@@ -217,12 +220,13 @@ def __init_finalize__(self, *args, **kwargs):
217220 self .wavelet_type = kwargs .get ('wavelet' )
218221 self .wavelet_kwargs = {}
219222
220- if self .wavelet_type == 'dgauss' :
221- self .wavelet_kwargs ['n' ] = kwargs .get ('n' , 1 )
223+ if isinstance (self .wavelet_type , str ):
224+ if self .wavelet_type == 'dgauss' :
225+ self .wavelet_kwargs ['n' ] = kwargs .get ('n' , 1 )
222226
223- if self .wavelet_type == 'gabor' :
224- self .wavelet_kwargs ['gamma' ] = kwargs .get ('gamma' , 1 )
225- self .wavelet_kwargs ['phi' ] = kwargs .get ('phi' , 0 )
227+ if self .wavelet_type == 'gabor' :
228+ self .wavelet_kwargs ['gamma' ] = kwargs .get ('gamma' , 1 )
229+ self .wavelet_kwargs ['phi' ] = kwargs .get ('phi' , 0 )
226230
227231 if not self .alias :
228232 for p in range (kwargs ['npoint' ]):
@@ -233,14 +237,16 @@ def wavelet(self):
233237 """
234238 Return a wavelet with a peak frequency ``f0`` at time ``t0``.
235239 """
236- if self .wavelet_type :
240+ if isinstance ( self .wavelet_type , str ) :
237241 return wavelet [self .wavelet_type ](
238242 self .time_values ,
239243 self .f0 ,
240244 1 if self .a is None else self .a ,
241245 self .t0 ,
242246 ** self .wavelet_kwargs
243247 )
248+ elif any (self .wavelet_type ):
249+ return self .wavelet_type
244250 else :
245251 raise NotImplementedError ('Wavelet not defined' )
246252
0 commit comments