
Commit c379a7a

Merge pull request #214 from basf/embeddings
Embeddings
2 parents: 4a76db9 + ac27a1d, commit c379a7a

43 files changed

Lines changed: 1327 additions & 1599 deletions


.github/workflows/pr-tests.yml

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+name: PR Unit Tests
+
+on:
+  pull_request:
+    branches:
+      - develop
+      - master # Add any other branches where you want to enforce tests
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10" # Change this to match your setup
+
+      - name: Install Poetry
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 -
+          echo "$HOME/.local/bin" >> $GITHUB_PATH
+          export PATH="$HOME/.local/bin:$PATH"
+
+      - name: Install Dependencies
+        run: |
+          python -m pip install --upgrade pip
+          poetry install
+          pip install pytest
+
+      - name: Install Package Locally
+        run: |
+          poetry build
+          pip install dist/*.whl # Install the built package to fix "No module named 'mambular'"
+
+      - name: Run Unit Tests
+        env:
+          PYTHONPATH: ${{ github.workspace }} # Ensure the package is discoverable
+        run: pytest tests/
+
+      - name: Verify Tests Passed
+        if: ${{ success() }}
+        run: echo "All tests passed! Pull request is allowed."
+
+      - name: Fail PR on Test Failure
+        if: ${{ failure() }}
+        run: exit 1 # This ensures the PR cannot be merged if tests fail
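
Note: the same build-and-test sequence can be reproduced outside CI for local debugging. The following is a minimal, hypothetical helper script (not part of this commit); it assumes Poetry and pytest are already installed and simply mirrors the workflow's poetry install, poetry build, wheel install, and pytest steps.

# run_pr_checks.py -- hypothetical local mirror of the workflow above (not part of this commit)
import glob
import subprocess
import sys

def run(cmd):
    # Run one command and abort on the first failure, like a CI step would.
    print("+", " ".join(cmd))
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    run(["poetry", "install"])                          # mirrors "Install Dependencies"
    run(["poetry", "build"])                            # mirrors "Install Package Locally"
    wheels = sorted(glob.glob("dist/*.whl"))
    if not wheels:
        sys.exit("poetry build produced no wheel in dist/")
    run([sys.executable, "-m", "pip", "install", wheels[-1]])
    run([sys.executable, "-m", "pytest", "tests/"])     # mirrors "Run Unit Tests"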

mambular/arch_utils/layer_utils/attention_utils.py

Lines changed: 2 additions & 9 deletions
@@ -5,7 +5,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from rotary_embedding_torch import RotaryEmbedding


 class GEGLU(nn.Module):
@@ -25,7 +24,7 @@ def FeedForward(dim, mult=4, dropout=0.0):


 class Attention(nn.Module):
-    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -34,18 +33,13 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
         self.dropout = nn.Dropout(dropout)
-        self.rotary = rotary
         dim = np.int64(dim / 2)
-        self.rotary_embedding = RotaryEmbedding(dim=dim)

     def forward(self, x):
         h = self.heads
         x = self.norm(x)
         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))  # type: ignore
-        if self.rotary:
-            q = self.rotary_embedding.rotate_queries_or_keys(q)
-            k = self.rotary_embedding.rotate_queries_or_keys(k)
         q = q * self.scale

         sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
@@ -61,7 +55,7 @@


 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False):
+    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
         super().__init__()
         self.layers = nn.ModuleList([])

@@ -74,7 +68,6 @@ def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary
                            heads=heads,
                            dim_head=dim_head,
                            dropout=attn_dropout,
-                            rotary=rotary,
                        ),
                        FeedForward(dim, dropout=ff_dropout),
                    ]
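
Note: with the rotary option removed, Attention and Transformer are constructed without a rotary flag. Below is a minimal usage sketch, not taken from this commit: the example dimensions and the expectation that the unchanged parts of the module (self.norm, self.scale, the output projection) behave as before are assumptions.

import torch
from mambular.arch_utils.layer_utils.attention_utils import Attention, Transformer

x = torch.randn(8, 16, 64)  # (batch, tokens, dim) -- example shape, not from this commit

attn = Attention(dim=64, heads=8, dim_head=32, dropout=0.1)  # no rotary argument anymore
out = attn(x)  # expected to keep the (8, 16, 64) shape

blocks = Transformer(
    dim=64, depth=2, heads=8, dim_head=32, attn_dropout=0.1, ff_dropout=0.1
)  # rotary keyword removed here as well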

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 38 additions & 19 deletions
@@ -6,7 +6,7 @@


 class EmbeddingLayer(nn.Module):
-    def __init__(self, num_feature_info, cat_feature_info, config):
+    def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config):
         """Embedding layer that handles numerical and categorical embeddings.

         Parameters
@@ -28,6 +28,7 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         self.layer_norm_after_embedding = getattr(
             config, "layer_norm_after_embedding", False
         )
+        self.embedding_projection = getattr(config, "embedding_projection", True)
         self.use_cls = getattr(config, "use_cls", False)
         self.cls_position = getattr(config, "cls_position", 0)
         self.embedding_dropout = (
@@ -100,6 +101,22 @@ def __init__(self, num_feature_info, cat_feature_info, config):
             ]
         )

+        if len(emb_feature_info) >= 1:
+            if self.embedding_projection:
+                self.emb_embeddings = nn.ModuleList(
+                    [
+                        nn.Sequential(
+                            nn.Linear(
+                                feature_info["dimension"],
+                                self.d_model,
+                                bias=self.embedding_bias,
+                            ),
+                            self.embedding_activation,
+                        )
+                        for feature_name, feature_info in emb_feature_info.items()
+                    ]
+                )
+
         # Class token if required
         if self.use_cls:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_model))
@@ -108,15 +125,12 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         if self.layer_norm_after_embedding:
             self.embedding_norm = nn.LayerNorm(self.d_model)

-    def forward(self, num_features=None, cat_features=None):
+    def forward(self, num_features, cat_features, emb_features):
         """Defines the forward pass of the model.

         Parameters
         ----------
-        num_features : Tensor, optional
-            Tensor containing the numerical features.
-        cat_features : Tensor, optional
-            Tensor containing the categorical features.
+        data: tuple of lists of tensors

         Returns
         -------
@@ -128,6 +142,7 @@ def forward(self, num_features=None, cat_features=None):
         ValueError
             If no features are provided to the model.
         """
+        num_embeddings, cat_embeddings, emb_embeddings = None, None, None

         # Class token initialization
         if self.use_cls:
@@ -147,8 +162,6 @@ def forward(self, num_features=None, cat_features=None):
             cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
             if self.layer_norm_after_embedding:
                 cat_embeddings = self.embedding_norm(cat_embeddings)
-        else:
-            cat_embeddings = None

         # Process numerical embeddings based on embedding_type
         if self.embedding_type == "plr":
@@ -161,25 +174,31 @@ def forward(self, num_features=None, cat_features=None):
                 num_embeddings = self.num_embeddings(num_features)
                 if self.layer_norm_after_embedding:
                     num_embeddings = self.embedding_norm(num_embeddings)
-            else:
-                num_embeddings = None
         else:
             # For linear and ndt embeddings, handle each feature individually
             if self.num_embeddings and num_features is not None:
                 num_embeddings = [emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)]  # type: ignore
                 num_embeddings = torch.stack(num_embeddings, dim=1)
                 if self.layer_norm_after_embedding:
                     num_embeddings = self.embedding_norm(num_embeddings)
+
+        if emb_features != []:
+            if self.embedding_projection:
+                emb_embeddings = [
+                    emb(emb_features[i]) for i, emb in enumerate(self.emb_embeddings)
+                ]
+                emb_embeddings = torch.stack(emb_embeddings, dim=1)
             else:
-                num_embeddings = None
-
-        # Combine categorical and numerical embeddings
-        if cat_embeddings is not None and num_embeddings is not None:
-            x = torch.cat([cat_embeddings, num_embeddings], dim=1)
-        elif cat_embeddings is not None:
-            x = cat_embeddings
-        elif num_embeddings is not None:
-            x = num_embeddings
+                emb_embeddings = torch.stack(emb_features, dim=1)
+            if self.layer_norm_after_embedding:
+                emb_embeddings = self.embedding_norm(emb_embeddings)
+
+        embeddings = [
+            e for e in [cat_embeddings, num_embeddings, emb_embeddings] if e is not None
+        ]
+
+        if embeddings:
+            x = torch.cat(embeddings, dim=1) if len(embeddings) > 1 else embeddings[0]
         else:
             raise ValueError("No features provided to the model.")
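Note: the forward pass now takes three groups of inputs: numerical features, categorical features, and pre-computed embedding features, each as a list of tensors. A minimal sketch of how such inputs might look follows; the shapes, the 768-dimensional embedding, and the construction of the layer itself are illustrative assumptions, not values from this commit.

import torch

batch = 32
num_features = [torch.randn(batch, 1) for _ in range(3)]             # one tensor per numerical feature
cat_features = [torch.randint(0, 10, (batch, 1)) for _ in range(2)]  # one tensor per categorical feature
emb_features = [torch.randn(batch, 768)]  # pre-computed embeddings; 768 must match
                                          # emb_feature_info[<name>]["dimension"]

# Hypothetical call -- num_feature_info, cat_feature_info, emb_feature_info and config
# come from the package's preprocessing/config machinery and are not reproduced here:
# layer = EmbeddingLayer(num_feature_info, cat_feature_info, emb_feature_info, config)
# x = layer(num_features, cat_features, emb_features)  # -> (batch, n_tokens [+ CLS], d_model)
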
mambular/base_models/basemodel.py

Lines changed: 2 additions & 2 deletions
@@ -223,7 +223,7 @@ def pool_sequence(self, out):
         else:
             raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}")

-    def encode(self, num_features, cat_features):
+    def encode(self, data):
         if not hasattr(self, "embedding_layer"):
             raise ValueError("The model does not have an embedding layer")

@@ -237,7 +237,7 @@
             raise ValueError("The model does not generate contextualized embeddings")

         # Get the actual layer and call it
-        x = self.embedding_layer(num_features=num_features, cat_features=cat_features)
+        x = self.embedding_layer(*data)

         if getattr(self.hparams, "shuffle_embeddings", False):
             x = x[:, self.perm, :]
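
Note: encode now receives the already-grouped inputs as a single argument and lets the embedding layer unpack them. A small hedged illustration with placeholder tensors (not from this commit):

import torch

# Placeholder feature groups; in practice they come from the dataloader/preprocessor.
num_features = [torch.randn(4, 1)]
cat_features = [torch.randint(0, 5, (4, 1))]
emb_features = [torch.randn(4, 128)]

data = (num_features, cat_features, emb_features)
# model.encode(data) then calls self.embedding_layer(*data), i.e.
# self.embedding_layer(num_features, cat_features, emb_features)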

mambular/base_models/ft_transformer.py

Lines changed: 9 additions & 14 deletions
@@ -6,6 +6,7 @@
 from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
 from ..configs.fttransformer_config import DefaultFTTransformerConfig
 from .basemodel import BaseModel
+import numpy as np


 class FTTransformer(BaseModel):
@@ -52,22 +53,18 @@ class FTTransformer(BaseModel):

     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
         num_classes=1,
         config: DefaultFTTransformerConfig = DefaultFTTransformerConfig(),  # noqa: B008
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        self.save_hyperparameters(ignore=["feature_information"])
         self.returns_ensemble = False
-        self.cat_feature_info = cat_feature_info
-        self.num_feature_info = num_feature_info

         # embedding layer
         self.embedding_layer = EmbeddingLayer(
-            num_feature_info=num_feature_info,
-            cat_feature_info=cat_feature_info,
+            *feature_information,
             config=config,
         )

@@ -87,25 +84,23 @@ def __init__(
         )

         # pooling
-        n_inputs = len(num_feature_info) + len(cat_feature_info)
+        n_inputs = np.sum([len(info) for info in feature_information])
         self.initialize_pooling_layers(config=config, n_inputs=n_inputs)

-    def forward(self, num_features, cat_features):
+    def forward(self, *data):
         """Defines the forward pass of the model.

         Parameters
         ----------
-        num_features : Tensor
-            Tensor containing the numerical features.
-        cat_features : Tensor
-            Tensor containing the categorical features.
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.

         Returns
         -------
         Tensor
             The output predictions of the model.
         """
-        x = self.embedding_layer(num_features, cat_features)
+        x = self.embedding_layer(*data)

         x = self.encoder(x)

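Note: constructors now receive the three feature-info dictionaries as one positional tuple, and the pooling layer sizes itself from their combined length. The snippet below only illustrates that n_inputs computation; the dictionary contents are made-up placeholders, since the real ones come from the package's preprocessor.

import numpy as np

# Placeholder feature-info dictionaries -- keys are illustrative, only their counts matter here.
num_feature_info = {"age": {}, "income": {}}
cat_feature_info = {"city": {}}
emb_feature_info = {"text_emb": {"dimension": 768}}

feature_information = (num_feature_info, cat_feature_info, emb_feature_info)

# Same expression the model now uses before initialize_pooling_layers:
n_inputs = np.sum([len(info) for info in feature_information])
print(n_inputs)  # 4 -> 2 numerical + 1 categorical + 1 embedding token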