@@ -1,23 +1,15 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
-
 import torch.nn as nn
+from .base_config import BaseConfig


 @dataclass
-class DefaultMambAttentionConfig:
+class DefaultMambAttentionConfig(BaseConfig):
     """Configuration class for the Default Mambular Attention model with predefined hyperparameters.

     Parameters
     ----------
-    lr : float, default=1e-04
-        Learning rate for the optimizer.
-    lr_patience : int, default=10
-        Number of epochs with no improvement after which learning rate will be reduced.
-    weight_decay : float, default=1e-06
-        Weight decay (L2 penalty) for the optimizer.
-    lr_factor : float, default=0.1
-        Factor by which the learning rate will be reduced.
     d_model : int, default=64
         Dimensionality of the model.
     n_layers : int, default=4
@@ -58,22 +50,6 @@ class DefaultMambAttentionConfig:
         Type of normalization used in the model.
     activation : callable, default=nn.SiLU()
         Activation function for the model.
-    layer_norm_eps : float, default=1e-05
-        Epsilon value for layer normalization.
-    num_embedding_activation : callable, default=nn.ReLU()
-        Activation function for numerical embeddings.
-    embedding_type : str, default="linear"
-        Type of embedding to use ('linear', etc.).
-    embedding_bias : bool, default=False
-        Whether to use bias in the embedding layers.
-    plr_lite : bool, default=False
-        Whether to use a lightweight version of Piecewise Linear Regression (PLR).
-    n_frequencies : int, default=48
-        Number of frequencies for PLR embeddings.
-    frequencies_init_scale : float, default=0.01
-        Initial scale for frequency parameters in embeddings.
-    layer_norm_after_embedding : bool, default=False
-        Whether to apply layer normalization after embedding layers.
     head_layer_sizes : list, default=()
         Sizes of the fully connected layers in the model's head.
     head_dropout : float, default=0.5
@@ -106,12 +82,6 @@ class DefaultMambAttentionConfig:
         Number of attention layers in the model.
     """

-    # Optimizer Parameters
-    lr: float = 1e-04
-    lr_patience: int = 10
-    weight_decay: float = 1e-06
-    lr_factor: float = 0.1
-
     # Architecture Parameters
     d_model: int = 64
     n_layers: int = 4
@@ -133,16 +103,6 @@ class DefaultMambAttentionConfig:
     dt_init_floor: float = 1e-04
     norm: str = "LayerNorm"
     activation: Callable = nn.SiLU()  # noqa: RUF009
-    layer_norm_eps: float = 1e-05
-
-    # Embedding Parameters
-    num_embedding_activation: Callable = nn.ReLU()  # noqa: RUF009
-    embedding_type: str = "linear"
-    embedding_bias: bool = False
-    plr_lite: bool = False
-    n_frequencies: int = 48
-    frequencies_init_scale: float = 0.01
-    layer_norm_after_embedding: bool = False

     # Head Parameters
     head_layer_sizes: list = field(default_factory=list)
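
The diff removes the optimizer and embedding fields from this config and makes it inherit from BaseConfig, so those defaults are presumably defined once on the shared base. As a rough sketch of what .base_config might contain, assuming the refactor simply hoists the deleted fields: the names and defaults below are copied verbatim from the removed lines, while the class body itself is an assumption, not the repository's actual base_config.py.

# Hedged sketch of .base_config.BaseConfig -- field names and defaults are
# taken from the lines removed above; the grouping and the class itself are
# assumptions about the refactor, not the actual source.
from collections.abc import Callable
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class BaseConfig:
    # Optimizer parameters (previously duplicated in each model config)
    lr: float = 1e-04
    lr_patience: int = 10
    weight_decay: float = 1e-06
    lr_factor: float = 0.1

    # Normalization and embedding parameters (previously duplicated as well)
    layer_norm_eps: float = 1e-05
    num_embedding_activation: Callable = nn.ReLU()  # noqa: RUF009
    embedding_type: str = "linear"
    embedding_bias: bool = False
    plr_lite: bool = False
    n_frequencies: int = 48
    frequencies_init_scale: float = 0.01
    layer_norm_after_embedding: bool = False

Under that assumption, the inherited fields are overridden like any other dataclass field, e.g. DefaultMambAttentionConfig(lr=5e-04, d_model=128), and the model-specific config stays focused on architecture and head parameters.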