Skip to content

Commit 56801dd

Browse files
authored
Merge pull request #90 from basf/layer_improvement
Layer improvement
2 parents cc92798 + 19b760c commit 56801dd

28 files changed

Lines changed: 1093 additions & 261 deletions
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import torch.nn as nn
2+
import torch
3+
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
9+
class Reshape(nn.Module):
    """Project a flat feature vector into ``j`` tokens of width ``dim``.

    The projection is realised by one of three interchangeable layers,
    selected via ``method`` ("linear", "embedding" or "conv1d"); in every
    case the result is viewed as ``(batch, j, dim)``.
    """

    def __init__(self, j, dim, method="linear"):
        super(Reshape, self).__init__()
        self.j = j
        self.dim = dim
        self.method = method

        if self.method == "linear":
            # Dense projection from dim to j * dim.
            self.layer = nn.Linear(dim, j * dim)
        elif self.method == "embedding":
            # Lookup-table projection (expects integer indices as input).
            self.layer = nn.Embedding(dim, j * dim)
        elif self.method == "conv1d":
            # Pointwise convolution acting on a dummy length-1 sequence.
            self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
        else:
            raise ValueError(f"Unsupported method '{method}' for reshaping.")

    def forward(self, x):
        """Map ``x`` of shape (batch, dim) to tokens of shape (batch, j, dim)."""
        n_samples = x.shape[0]

        if self.method == "conv1d":
            # Conv1d needs a length dimension: append one, convolve, drop it.
            widened = self.layer(x.unsqueeze(-1)).squeeze(-1)
        else:
            # "linear" and "embedding" both yield (batch, j * dim) directly.
            widened = self.layer(x)

        return widened.view(n_samples, self.j, self.dim)
46+
47+
48+
class AttentionNetBlock(nn.Module):
    """Reshape a flat input into tokens, encode them with a transformer,
    and pool the residual-augmented tokens into one output vector.

    The input is expanded to ``channels`` tokens by :class:`Reshape`, run
    through an ``n_layers``-deep ``nn.TransformerEncoder``, added back to
    the pre-encoder tokens (residual), summed over the token axis, and
    finally mapped to ``output_dim`` through a linear head.
    """

    def __init__(
        self,
        channels,
        in_channels,
        d_model,
        n_heads,
        n_layers,
        dim_feedforward,
        transformer_activation,
        output_dim,
        attn_dropout,
        layer_norm_eps,
        norm_first,
        bias,
        activation,
        embedding_activation,
        norm_f,
        method,
    ):
        super(AttentionNetBlock, self).__init__()

        self.reshape = Reshape(channels, in_channels, method)

        # Encoder layer is built inline; the stack applies norm_f at the end.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                batch_first=True,
                dim_feedforward=dim_feedforward,
                dropout=attn_dropout,
                activation=transformer_activation,
                layer_norm_eps=layer_norm_eps,
                norm_first=norm_first,
                bias=bias,
            ),
            num_layers=n_layers,
            norm=norm_f,
        )

        self.linear = nn.Linear(d_model, output_dim)
        self.activation = activation
        self.embedding_activation = embedding_activation

    def forward(self, x):
        """Encode ``x`` (batch, in_channels) and return (batch, output_dim)."""
        tokens = self.reshape(x)
        encoded = self.encoder(self.embedding_activation(tokens))
        # Residual around the encoder, then sum-pool over the token axis.
        pooled = (tokens + encoded).sum(dim=1)
        return self.activation(self.linear(pooled))
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import torch.nn as nn
2+
import torch
3+
from rotary_embedding_torch import RotaryEmbedding
4+
from einops import rearrange
5+
import torch.nn.functional as F
6+
import numpy as np
7+
8+
9+
class GEGLU(nn.Module):
    """GELU-gated linear unit: split the last dim in half, gate one half
    with the GELU of the other (halves the feature dimension)."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.gelu(gate)
13+
14+
15+
def FeedForward(dim, mult=4, dropout=0.0):
    """Build a pre-norm position-wise feed-forward block with a GEGLU gate.

    The first projection doubles the hidden width (``dim * mult * 2``) so
    GEGLU can split it into equally sized value and gate halves; the final
    projection maps back down to ``dim``.
    """
    hidden = dim * mult
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*stages)
23+
24+
25+
class Attention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings.

    Parameters
    ----------
    dim : int
        Input/output feature dimensionality.
    heads : int, optional
        Number of attention heads. Default is 8.
    dim_head : int, optional
        Per-head feature dimensionality. Default is 64.
    dropout : float, optional
        Dropout applied to the post-softmax attention weights. Default 0.0.
    rotary : bool, optional
        If True, rotary embeddings are applied to queries and keys. Default False.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.rotary = rotary
        # Integer floor division replaces the original float divide plus
        # np.int64 cast (equivalent for positive dim, idiomatic, no numpy).
        # NOTE(review): rotary acts per-head on width dim_head, so deriving
        # the rotary dim from the model dim looks suspicious whenever
        # dim // 2 > dim_head — confirm against RotaryEmbedding's contract.
        # Built even when rotary=False so the module/state_dict layout stays
        # unchanged for existing checkpoints.
        self.rotary_embedding = RotaryEmbedding(dim=dim // 2)

    def forward(self, x):
        """Attend over ``x`` (batch, seq, dim — TODO confirm); returns the
        attended output and the post-softmax attention weights."""
        h = self.heads
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        if self.rotary:
            q = self.rotary_embedding.rotate_queries_or_keys(q)
            k = self.rotary_embedding.rotate_queries_or_keys(k)
        # Scale queries once instead of scaling the similarity matrix.
        q = q * self.scale

        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

        attn = sim.softmax(dim=-1)
        dropped_attn = self.dropout(attn)

        out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        out = self.to_out(out)

        # Returns the un-dropped weights so callers can inspect attention maps.
        return out, attn
59+
60+
61+
class Transformer(nn.Module):
    """Stack of ``depth`` (self-attention, feed-forward) pairs, each wrapped
    in a residual connection; optionally returns the attention maps."""

    def __init__(
        self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
    ):
        super().__init__()
        # One (attention, feed-forward) pair per level of depth.
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    Attention(
                        dim,
                        heads=heads,
                        dim_head=dim_head,
                        dropout=attn_dropout,
                        rotary=rotary,
                    ),
                    FeedForward(dim, dropout=ff_dropout),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x, return_attn=False):
        """Run the stack over ``x``; with ``return_attn`` also return the
        stacked post-softmax attention weights from every layer."""
        attn_maps = []

        for attention, feed_forward in self.layers:
            attended, weights = attention(x)
            attn_maps.append(weights)

            x = x + attended
            x = x + feed_forward(x)

        if return_attn:
            return x, torch.stack(attn_maps)

        return x
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class EmbeddingLayer(nn.Module):
    def __init__(
        self,
        num_feature_info,
        cat_feature_info,
        d_model,
        embedding_activation=nn.Identity(),
        layer_norm_after_embedding=False,
        use_cls=False,
        cls_position=0,
    ):
        """
        Embedding layer that handles numerical and categorical embeddings.

        Parameters
        ----------
        num_feature_info : dict
            Dictionary where keys are numerical feature names and values are their respective input dimensions.
        cat_feature_info : dict
            Dictionary where keys are categorical feature names and values are the number of categories for each feature.
        d_model : int
            Dimensionality of the embeddings.
        embedding_activation : nn.Module, optional
            Activation function to apply after embedding. Default is `nn.Identity()`.
        layer_norm_after_embedding : bool, optional
            If True, applies layer normalization after embeddings. Default is `False`.
        use_cls : bool, optional
            If True, includes a class token in the embeddings. Default is `False`.
        cls_position : int, optional
            Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.

        Methods
        -------
        forward(num_features=None, cat_features=None)
            Defines the forward pass of the model.
        """
        super(EmbeddingLayer, self).__init__()

        self.d_model = d_model
        self.embedding_activation = embedding_activation
        self.layer_norm_after_embedding = layer_norm_after_embedding
        self.use_cls = use_cls
        self.cls_position = cls_position

        # One linear (bias-free) embedding per numerical feature.
        self.num_embeddings = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_shape, d_model, bias=False),
                    self.embedding_activation,
                )
                for feature_name, input_shape in num_feature_info.items()
            ]
        )

        # One lookup table per categorical feature; +1 row reserves an
        # extra index (presumably for unknown/padding — verify against the
        # preprocessing that produces the category codes).
        self.cat_embeddings = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Embedding(num_categories + 1, d_model),
                    self.embedding_activation,
                )
                for feature_name, num_categories in cat_feature_info.items()
            ]
        )

        if self.use_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        if layer_norm_after_embedding:
            self.embedding_norm = nn.LayerNorm(d_model)

        # Sequence length before any CLS token is prepended/appended.
        self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

    def forward(self, num_features=None, cat_features=None):
        """
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : Tensor, optional
            Tensor containing the numerical features.
        cat_features : Tensor, optional
            Tensor containing the categorical features.

        Returns
        -------
        Tensor
            The output embeddings of the model.

        Raises
        ------
        ValueError
            If no features are provided to the model.
        """
        if self.use_cls:
            # BUGFIX: the original tested `cat_features != []`, which is True
            # for `cat_features is None` and then crashed on `cat_features[0]`.
            # Take the batch size from whichever feature list is non-empty.
            # (Assumes the features arrive as lists of per-feature tensors.)
            if cat_features:
                batch_size = cat_features[0].size(0)
            elif num_features:
                batch_size = num_features[0].size(0)
            else:
                raise ValueError("No features provided to the model.")
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        if self.cat_embeddings and cat_features is not None:
            cat_embeddings = [
                emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
            ]
            cat_embeddings = torch.stack(cat_embeddings, dim=1)
            # Drop the singleton dim left by (batch, 1) integer inputs.
            cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
            if self.layer_norm_after_embedding:
                cat_embeddings = self.embedding_norm(cat_embeddings)
        else:
            cat_embeddings = None

        if self.num_embeddings and num_features is not None:
            num_embeddings = [
                emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
            ]
            num_embeddings = torch.stack(num_embeddings, dim=1)
            if self.layer_norm_after_embedding:
                num_embeddings = self.embedding_norm(num_embeddings)
        else:
            num_embeddings = None

        # Concatenate along the token axis: categorical tokens first.
        if cat_embeddings is not None and num_embeddings is not None:
            x = torch.cat([cat_embeddings, num_embeddings], dim=1)
        elif cat_embeddings is not None:
            x = cat_embeddings
        elif num_embeddings is not None:
            x = num_embeddings
        else:
            raise ValueError("No features provided to the model.")

        if self.use_cls:
            if self.cls_position == 0:
                x = torch.cat([cls_tokens, x], dim=1)
            elif self.cls_position == 1:
                x = torch.cat([x, cls_tokens], dim=1)
            else:
                raise ValueError(
                    "Invalid cls_position value. It should be either 0 or 1."
                )

        return x
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class PeriodicLinearEncodingLayer(nn.Module):
6+
def __init__(self, bins=10, learn_bins=True):
7+
super(PeriodicLinearEncodingLayer, self).__init__()
8+
self.bins = bins
9+
self.learn_bins = learn_bins
10+
11+
if self.learn_bins:
12+
# Learnable bin boundaries
13+
self.bin_boundaries = nn.Parameter(torch.linspace(0, 1, self.bins + 1))
14+
else:
15+
self.bin_boundaries = torch.linspace(-1, 1, self.bins + 1)
16+
17+
def forward(self, x):
18+
if self.learn_bins:
19+
# Ensure bin boundaries are sorted
20+
sorted_bins = torch.sort(self.bin_boundaries)[0]
21+
else:
22+
sorted_bins = self.bin_boundaries
23+
24+
# Initialize z with zeros
25+
z = torch.zeros(x.size(0), self.bins, device=x.device)
26+
27+
for t in range(1, self.bins + 1):
28+
b_t_1 = sorted_bins[t - 1]
29+
b_t = sorted_bins[t]
30+
mask1 = x < b_t_1
31+
mask2 = x >= b_t
32+
mask3 = (x >= b_t_1) & (x < b_t)
33+
34+
z[mask1.squeeze(), t - 1] = 0
35+
z[mask2.squeeze(), t - 1] = 1
36+
z[mask3.squeeze(), t - 1] = (x[mask3] - b_t_1) / (b_t - b_t_1)
37+
38+
return z

0 commit comments

Comments
 (0)