@@ -41,6 +41,8 @@ def __init__(
         dt_init_floor=1e-04,
         norm=RMSNorm,
         activation=F.silu,
+        bidirectional=False,
+        use_learnable_interaction=False,
     ):
         super().__init__()
 
@@ -62,6 +64,8 @@ def __init__(
                     dt_init_floor,
                     norm,
                     activation,
+                    bidirectional,
+                    use_learnable_interaction,
                 )
                 for _ in range(n_layers)
             ]
@@ -99,6 +103,8 @@ def __init__(
         dt_init_floor=1e-04,
         norm=RMSNorm,
         activation=F.silu,
+        bidirectional=False,
+        use_learnable_interaction=False,
     ):
         super().__init__()
 
@@ -141,6 +147,8 @@ def __init__(
             dt_min=dt_min,
             dt_init_floor=dt_init_floor,
             activation=activation,
+            bidirectional=bidirectional,
+            use_learnable_interaction=use_learnable_interaction,
         )
         self.norm = norm(d_model)
 
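Taken together, the hunks above only thread the two new switches from the top-level model down to each `MambaBlock`. A minimal usage sketch, not from the diff: the keyword names mirror the block's `__init__` below, but the concrete values and the assumption that `MambaBlock` is importable on its own are mine.

```python
import torch

# Hypothetical values; only the keyword names are taken from the diff.
block = MambaBlock(
    d_model=64,
    expand_factor=2,
    bias=False,
    d_conv=4,
    conv_bias=True,
    dropout=0.0,
    dt_rank=4,
    d_state=16,
    dt_scale=1.0,
    dt_init="random",
    dt_max=1e-1,
    dt_min=1e-3,
    dt_init_floor=1e-4,
    bidirectional=True,              # new: adds a right-to-left scan path
    use_learnable_interaction=True,  # new: mixes channels before the SSM
)
y = block(torch.randn(2, 10, 64))    # (batch, seq_len, d_model) -> same shape
```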
@@ -153,14 +161,14 @@ class MambaBlock(nn.Module):
     """MambaBlock module containing the main computational components.
 
     Attributes:
-        config (MambularConfig): Configuration object for the MambaBlock.
-        in_proj (nn.Linear): Linear projection for input.
-        conv1d (nn.Conv1d): 1D convolutional layer.
-        x_proj (nn.Linear): Linear projection for input-dependent tensors.
-        dt_proj (nn.Linear): Linear projection for dynamical time.
-        A_log (nn.Parameter): Logarithmically stored A tensor.
-        D (nn.Parameter): Tensor for D component.
+        in_proj_fwd (nn.Linear): Input projection (a _bwd twin is created when bidirectional=True).
+        conv1d_fwd (nn.Conv1d): Depthwise 1D convolution over the sequence.
+        x_proj_fwd (nn.Linear): Projection producing the input-dependent delta, B, and C tensors.
+        dt_proj_fwd (nn.Linear): Projection for the discretization time step delta.
+        A_log_fwd (nn.Parameter): State matrix A, stored as its logarithm.
+        D_fwd (nn.Parameter): Skip-connection parameter D.
         out_proj (nn.Linear): Linear projection for output.
+        learnable_interaction (LearnableFeatureInteraction): Learnable feature interaction layer.
     """
 
     def __init__(
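For orientation, here is the recurrence these components parameterize. The body of `selective_scan_seq` is elided by this diff, so the discretization shown is the standard Mamba one and is an assumption here: `A = -exp(A_log)` keeps the state transition stable, `x_proj` produces the input-dependent delta, B, and C, and `dt_proj` followed by a softplus yields the step size.

```latex
% Assumed standard Mamba discretization (selective_scan_seq is not shown in full):
\Delta_t = \operatorname{softplus}(W_{\Delta}\,\delta_t + b_{\Delta}), \qquad
\bar{A}_t = \exp(\Delta_t A), \qquad \bar{B}_t = \Delta_t B_t,
\]
\[
h_t = \bar{A}_t \odot h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t\, h_t + D\, x_t .
```

The bias initialization `inv_dt = dt + log(-expm1(-dt))` seen below is exactly `softplus^{-1}(dt)`, so after the softplus the step sizes start out near the values sampled from `[dt_min, dt_max]`.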
@@ -179,88 +187,154 @@ def __init__(
         dt_min=1e-03,
         dt_init_floor=1e-04,
         activation=F.silu,
+        bidirectional=False,
+        use_learnable_interaction=False,
     ):
         super().__init__()
         self.d_inner = d_model * expand_factor
+        self.bidirectional = bidirectional
+        self.use_learnable_interaction = use_learnable_interaction
 
-        self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
+        self.in_proj_fwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
+        if self.bidirectional:
+            self.in_proj_bwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
 
-        self.conv1d = nn.Conv1d(
+        self.conv1d_fwd = nn.Conv1d(
             in_channels=self.d_inner,
             out_channels=self.d_inner,
             kernel_size=d_conv,
             bias=conv_bias,
             groups=self.d_inner,
             padding=d_conv - 1,
         )
+        if self.bidirectional:
+            self.conv1d_bwd = nn.Conv1d(
+                in_channels=self.d_inner,
+                out_channels=self.d_inner,
+                kernel_size=d_conv,
+                bias=conv_bias,
+                groups=self.d_inner,
+                padding=d_conv - 1,
+            )
 
         self.dropout = nn.Dropout(dropout)
         self.activation = activation
 
-        self.x_proj = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)
+        if self.use_learnable_interaction:
+            self.learnable_interaction = LearnableFeatureInteraction(self.d_inner)
+
+        self.x_proj_fwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)
+        if self.bidirectional:
+            self.x_proj_bwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)
 
-        self.dt_proj = nn.Linear(dt_rank, self.d_inner, bias=True)
+        self.dt_proj_fwd = nn.Linear(dt_rank, self.d_inner, bias=True)
+        if self.bidirectional:
+            self.dt_proj_bwd = nn.Linear(dt_rank, self.d_inner, bias=True)
 
         dt_init_std = dt_rank**-0.5 * dt_scale
         if dt_init == "constant":
-            nn.init.constant_(self.dt_proj.weight, dt_init_std)
+            nn.init.constant_(self.dt_proj_fwd.weight, dt_init_std)
+            if self.bidirectional:
+                nn.init.constant_(self.dt_proj_bwd.weight, dt_init_std)
         elif dt_init == "random":
-            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+            nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
+            if self.bidirectional:
+                nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
         else:
             raise NotImplementedError
 
-        dt = torch.exp(
+        dt_fwd = torch.exp(
             torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
             + math.log(dt_min)
         ).clamp(min=dt_init_floor)
-        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        inv_dt_fwd = dt_fwd + torch.log(-torch.expm1(-dt_fwd))
         with torch.no_grad():
-            self.dt_proj.bias.copy_(inv_dt)
+            self.dt_proj_fwd.bias.copy_(inv_dt_fwd)
+
+        if self.bidirectional:
+            dt_bwd = torch.exp(
+                torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+                + math.log(dt_min)
+            ).clamp(min=dt_init_floor)
+            inv_dt_bwd = dt_bwd + torch.log(-torch.expm1(-dt_bwd))
+            with torch.no_grad():
+                self.dt_proj_bwd.bias.copy_(inv_dt_bwd)
 
         A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
-        self.A_log = nn.Parameter(torch.log(A))
-        self.D = nn.Parameter(torch.ones(self.d_inner))
+        self.A_log_fwd = nn.Parameter(torch.log(A))
+        if self.bidirectional:
+            self.A_log_bwd = nn.Parameter(torch.log(A))
+
+        self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
+        if self.bidirectional:
+            self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
         self.dt_rank = dt_rank
         self.d_state = d_state
 
     def forward(self, x):
         _, L, _ = x.shape
 
-        xz = self.in_proj(x)
-        x, z = xz.chunk(2, dim=-1)
+        xz_fwd = self.in_proj_fwd(x)
+        x_fwd, z_fwd = xz_fwd.chunk(2, dim=-1)
 
-        x = x.transpose(1, 2)
-        x = self.conv1d(x)[:, :, :L]
-        x = x.transpose(1, 2)
+        x_fwd = x_fwd.transpose(1, 2)
+        x_fwd = self.conv1d_fwd(x_fwd)[:, :, :L]
+        x_fwd = x_fwd.transpose(1, 2)
 
-        x = self.activation(x)
-        x = self.dropout(x)
-        y = self.ssm(x)
+        if self.bidirectional:
+            xz_bwd = self.in_proj_bwd(x)
+            x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1)
 
-        z = self.activation(z)
-        z = self.dropout(z)
+            x_bwd = x_bwd.transpose(1, 2)
+            x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L]
+            x_bwd = x_bwd.transpose(1, 2)
 
-        output = y * z
-        output = self.out_proj(output)
+        if self.use_learnable_interaction:
+            x_fwd = self.learnable_interaction(x_fwd)
+            if self.bidirectional:
+                x_bwd = self.learnable_interaction(x_bwd)
 
-        return output
+        x_fwd = self.activation(x_fwd)
+        x_fwd = self.dropout(x_fwd)
+        y_fwd = self.ssm(x_fwd, forward=True)
 
-    def ssm(self, x):
-        A = -torch.exp(self.A_log.float())
-        D = self.D.float()
+        if self.bidirectional:
+            x_bwd = self.activation(x_bwd)
+            x_bwd = self.dropout(x_bwd)
+            y_bwd = self.ssm(torch.flip(x_bwd, [1]), forward=False)
+            y = y_fwd + torch.flip(y_bwd, [1])
+        else:
+            y = y_fwd
 
-        deltaBC = self.x_proj(x)
+        z_fwd = self.activation(z_fwd)
+        z_fwd = self.dropout(z_fwd)
 
-        delta, B, C = torch.split(
-            deltaBC,
-            [self.dt_rank, self.d_state, self.d_state],
-            dim=-1,
-        )
-        delta = F.softplus(self.dt_proj(delta))
+        output = y * z_fwd
+        output = self.out_proj(output)
 
-        y = self.selective_scan_seq(x, delta, A, B, C, D)
+        return output
+
+    def ssm(self, x, forward=True):
+        if forward:
+            A = -torch.exp(self.A_log_fwd.float())
+            D = self.D_fwd.float()
+            deltaBC = self.x_proj_fwd(x)
+            delta, B, C = torch.split(
+                deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
+            )
+            delta = F.softplus(self.dt_proj_fwd(delta))
+        else:
+            A = -torch.exp(self.A_log_bwd.float())
+            D = self.D_bwd.float()
+            deltaBC = self.x_proj_bwd(x)
+            delta, B, C = torch.split(
+                deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
+            )
+            delta = F.softplus(self.dt_proj_bwd(delta))
 
+        y = self.selective_scan_seq(x, delta, A, B, C, D)
         return y
 
     def selective_scan_seq(self, x, delta, A, B, C, D):
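The bidirectional path reuses the same causal scan: the input is flipped along the sequence dimension, scanned, and the result flipped back, so the `_bwd` branch aggregates information from the future rather than the past. A toy illustration, with `torch.cumsum` as my stand-in for the SSM scan (it is not the diff's kernel):

```python
import torch

x = torch.randn(2, 8, 4)  # (batch, seq_len, channels)

# flip -> causal scan -> flip back: position t now aggregates over s >= t
y_bwd = torch.flip(torch.cumsum(torch.flip(x, [1]), dim=1), [1])

assert torch.allclose(y_bwd[:, -1], x[:, -1])  # the last step sees only itself
```

Note also that the forward pass adds `y_fwd` and the flipped `y_bwd` rather than concatenating them, which keeps `d_inner` (and hence `out_proj`) unchanged.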
@@ -285,3 +359,15 @@ def selective_scan_seq(self, x, delta, A, B, C, D):
         y = y + D * x
 
         return y
+
+
+class LearnableFeatureInteraction(nn.Module):
+    def __init__(self, n_vars):
+        super().__init__()
+        self.interaction_weights = nn.Parameter(torch.Tensor(n_vars, n_vars))
+        nn.init.xavier_uniform_(self.interaction_weights)
+
+    def forward(self, x):
+        batch_size, n_vars, d_model = x.size()
+        interactions = torch.matmul(x, self.interaction_weights)
+        return interactions.view(batch_size, n_vars, d_model)
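A note on shapes for the new class: despite the `n_vars` naming, `MambaBlock` constructs it with `self.d_inner`, and `torch.matmul(x, W)` with `W` of shape `(n_vars, n_vars)` contracts the *last* dimension of `x`. Applied to `(batch, seq_len, d_inner)` tensors it therefore mixes channels at each position, and the trailing `.view` is a no-op. A quick check, assuming the class above is in scope (values are arbitrary):

```python
import torch

layer = LearnableFeatureInteraction(n_vars=8)  # weights: (8, 8)
x = torch.randn(2, 5, 8)                       # (batch, seq_len, d_inner=8)
out = layer(x)
assert out.shape == x.shape                    # right-matmul mixes only the last dim
```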