from collections.abc import Callable
from dataclasses import dataclass, field
import torch.nn as nn
from .base_config import BaseConfig


@dataclass
class DefaultTangosConfig(BaseConfig):
    """Configuration class for the default TANGOS model (a Multi-Layer Perceptron trained with TANGOS regularization) with predefined hyperparameters.

    Parameters
    ----------
    layer_sizes : list, default=[256, 128, 32]
        Sizes of the hidden layers in the MLP.
    activation : callable, default=nn.ReLU()
        Activation function applied between the MLP layers.
    skip_layers : bool, default=False
        Whether to skip layers in the MLP.
    dropout : float, default=0.2
        Dropout rate for regularization.
    use_glu : bool, default=False
        Whether to use Gated Linear Units (GLU) in the MLP.
    skip_connections : bool, default=False
        Whether to use skip connections in the MLP.
    lamda1 : float, default=0.5
        Weight of the TANGOS specialization penalty.
    lamda2 : float, default=0.1
        Weight of the TANGOS orthogonalization penalty.
    subsample : float, default=0.5
        Fraction of neuron pairs subsampled when computing the orthogonalization penalty.
    """

    # Architecture Parameters
    layer_sizes: list = field(default_factory=lambda: [256, 128, 32])
    activation: Callable = nn.ReLU()  # noqa: RUF009
    skip_layers: bool = False
    dropout: float = 0.2
    use_glu: bool = False
    skip_connections: bool = False

    # TANGOS regularization parameters
    lamda1: float = 0.5
    lamda2: float = 0.1
    subsample: float = 0.5
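As context for the three TANGOS-specific fields, here is a minimal sketch of how lamda1, lamda2, and subsample are typically consumed by a TANGOS-style regularizer: lamda1 weights a specialization penalty (sparse per-neuron input attributions), lamda2 weights an orthogonalization penalty (decorrelated attributions across neuron pairs), and subsample sets the fraction of neuron pairs sampled per step. This is an illustrative sketch, not this repository's training code; the tangos_penalty helper and the hidden and x tensors are assumed names.

# Illustrative sketch only: how the TANGOS hyperparameters above might be used.
import itertools
import random

import torch
import torch.nn.functional as F


def tangos_penalty(hidden, x, lamda1=0.5, lamda2=0.1, subsample=0.5):
    """Sketch of a TANGOS-style penalty.

    hidden : (batch, n_neurons) activations of one hidden layer
    x      : (batch, n_features) inputs with requires_grad=True
    """
    n_neurons = hidden.shape[1]

    # Per-neuron attributions: gradient of each neuron's summed activation w.r.t. the inputs.
    attributions = torch.stack(
        [
            torch.autograd.grad(hidden[:, i].sum(), x, retain_graph=True, create_graph=True)[0]
            for i in range(n_neurons)
        ],
        dim=1,
    )  # (batch, n_neurons, n_features)

    # Specialization term (weighted by lamda1): encourage sparse attributions via an L1 penalty.
    specialization = attributions.abs().mean()

    # Orthogonalization term (weighted by lamda2): penalize cosine similarity between the
    # attributions of a random subsample of neuron pairs.
    pairs = list(itertools.combinations(range(n_neurons), 2))
    if pairs:
        sampled = random.sample(pairs, max(1, int(subsample * len(pairs))))
        orthogonalization = torch.stack(
            [F.cosine_similarity(attributions[:, i], attributions[:, j], dim=1).abs().mean() for i, j in sampled]
        ).mean()
    else:
        orthogonalization = hidden.new_zeros(())

    return lamda1 * specialization + lamda2 * orthogonalization

Because the config is a plain dataclass with defaults, individual fields can typically be overridden per run, e.g. DefaultTangosConfig(dropout=0.3, lamda1=1.0), assuming BaseConfig adds no required fields.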