-
Notifications
You must be signed in to change notification settings - Fork 100
Expand file tree
/
Copy pathkan_block.py
More file actions
158 lines (133 loc) · 5.87 KB
/
kan_block.py
File metadata and controls
158 lines (133 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Module for the Kolmogorov-Arnold Network block."""
import torch
from pina._src.model.vectorized_spline import VectorizedSpline
from pina._src.core.utils import check_consistency, check_positive_integer
class KANBlock(torch.nn.Module):
    """
    The inner block of the Kolmogorov-Arnold Network (KAN).

    The block applies a spline transformation to the input, optionally combined
    with a linear transformation of a base activation function. The result is
    summed across input dimensions to produce the final output.

    .. seealso::

        **Original reference**:
        Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
        Hou T., Tegmark M. (2025).
        *KAN: Kolmogorov-Arnold Networks*.
        arXiv preprint `arXiv:2404.19756
        <https://arxiv.org/abs/2404.19756>`_.
    """

    def __init__(
        self,
        input_dimensions,
        output_dimensions,
        spline_order=3,
        n_knots=10,
        grid_range=(0, 1),
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KANBlock` class.

        :param int input_dimensions: The number of input features.
        :param int output_dimensions: The number of output features.
        :param int spline_order: The order of each spline. Each spline has
            ``n_total_knots - spline_order`` control points. Default is 3.
        :param int n_knots: The total number of knots of each spline, including
            the ``spline_order`` repeated knots placed at each bound of
            ``grid_range``. Values smaller than ``2 * spline_order`` are raised
            to ``2 * spline_order`` (i.e. no internal knots). Default is 10.
        :param grid_range: The range ``[min, max]`` spanned by the spline
            knots. It must be a list or a tuple of two numbers with
            ``min < max``. Default is ``(0, 1)``.
        :type grid_range: list | tuple
        :param type base_function: The base activation class, a subclass of
            :class:`torch.nn.Module`. It is instantiated with no arguments and
            applied to the input before the base linear transformation.
            Default is :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing each spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int
        :raises ValueError: If ``grid_range`` is not a list or tuple with two
            elements, or if its bounds do not satisfy ``min < max``.
        """
        super().__init__()

        # Structural check on grid_range comes first: len() and indexing below
        # require a list/tuple, and scalars or sets must be rejected with the
        # documented ValueError rather than a TypeError.
        if not isinstance(grid_range, (list, tuple)) or len(grid_range) != 2:
            raise ValueError("Grid must be a list or tuple with two elements.")

        # Check consistency
        check_consistency(base_function, torch.nn.Module, subclass=True)
        check_positive_integer(input_dimensions, strict=True)
        check_positive_integer(output_dimensions, strict=True)
        check_positive_integer(spline_order, strict=True)
        check_positive_integer(n_knots, strict=True)
        check_consistency(use_base_linear, bool)
        check_consistency(use_bias, bool)
        check_consistency(init_scale_spline, (int, float))
        check_consistency(init_scale_base, (int, float))
        check_consistency(grid_range, (int, float))

        lower, upper = grid_range
        # Equal bounds would collapse every knot onto one point; reversed
        # bounds would produce a decreasing knot vector.
        if lower >= upper:
            raise ValueError("Grid range must satisfy min < max.")

        # Clamped knot vector: spline_order copies of each bound, plus
        # uniformly spaced internal knots strictly inside (lower, upper).
        initial_knots = torch.ones(spline_order) * lower
        final_knots = torch.ones(spline_order) * upper
        n_internal = max(0, n_knots - 2 * spline_order)
        internal_knots = torch.linspace(lower, upper, n_internal + 2)[1:-1]
        knots = torch.cat((initial_knots, internal_knots, final_knots))

        # One copy of the knot vector per input dimension:
        # shape (input_dimensions, n_total_knots).
        knots = knots.unsqueeze(0).repeat(input_dimensions, 1)

        # Control points: shape (input_dimensions, output_dimensions, n_ctrl)
        # with n_ctrl = n_total_knots - spline_order.
        # NOTE(review): this count implies the order = degree + 1 convention
        # (spline_order=3 -> quadratic); confirm against VectorizedSpline.
        control_points = (
            torch.randn(
                input_dimensions,
                output_dimensions,
                knots.shape[-1] - spline_order,
            )
            * init_scale_spline
        )

        # Define the vectorized spline module
        self.spline = VectorizedSpline(
            order=spline_order, knots=knots, control_points=control_points
        )

        # The base function is instantiated even when the base linear map is
        # disabled; forward() only uses it when base_weight is not None.
        self.base_function = base_function()

        # Base linear weights, shape (output_dimensions, input_dimensions),
        # scaled by 1/sqrt(input_dimensions).
        if use_base_linear:
            self.base_weight = torch.nn.Parameter(
                torch.randn(output_dimensions, input_dimensions)
                * (init_scale_base / (input_dimensions**0.5))
            )
        else:
            self.register_parameter("base_weight", None)

        # Bias term, initialized to zeros.
        if use_bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_dimensions))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Forward pass of the Kolmogorov-Arnold block.

        The input is passed through the spline transformation, optionally
        combined with a linear transformation of the base function output, and
        then summed across input dimensions to produce the final output.

        :param x: The input tensor for the model. When the base linear map is
            enabled it must be 2-D, ``(batch, input_dimensions)``, as required
            by the ``"bi"`` einsum subscript.
        :type x: torch.Tensor | LabelTensor
        :return: The output tensor of the model, of shape
            ``(batch, output_dimensions)``.
        :rtype: torch.Tensor | LabelTensor
        """
        # Presumably (batch, input_dimensions, output_dimensions), so that it
        # can be added to base_out below -- TODO confirm against VectorizedSpline.
        y = self.spline(x)

        if self.base_weight is not None:
            base_x = self.base_function(x)
            # Per-input contribution: base_out[b, i, o] = base_x[b, i] * W[o, i]
            base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight)
            y = y + base_out

        # Aggregate contributions from all input dimensions.
        y = y.sum(dim=1)

        if self.bias is not None:
            y = y + self.bias

        return y