Commit d320552

Merge pull request #238 from basf/refactorization
Refactorization
2 parents 628182f + c75ef88 commit d320552

10 files changed

Lines changed: 550 additions & 4 deletions

mambular/__version__.py

Lines changed: 3 additions & 1 deletion
@@ -16,4 +16,6 @@
 #
 
 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "1.2.1"
+
+__version__ = "1.3.0"
+

mambular/arch_utils/enode_utils.py

Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mambular.arch_utils.layer_utils.sparsemax import sparsemax, sparsemoid
from .data_aware_initialization import ModuleWithInit
from .numpy_utils import check_numpy
import numpy as np
from warnings import warn


class ODSTE(ModuleWithInit):
    def __init__(
        self,
        in_features,  # J (number of features)
        num_trees,
        embed_dim,  # D (embedding dimension per feature)
        depth=6,
        tree_dim=1,
        flatten_output=True,
        choice_function=sparsemax,
        bin_function=sparsemoid,
        initialize_response_=nn.init.normal_,
        initialize_selection_logits_=nn.init.uniform_,
        threshold_init_beta=1.0,
        threshold_init_cutoff=1.0,
    ):
        """Oblivious Differentiable Sparsemax Trees (ODST) with feature & embedding splitting."""
        super().__init__()
        self.depth, self.num_trees, self.tree_dim, self.flatten_output = (
            depth,
            num_trees,
            tree_dim,
            flatten_output,
        )
        self.choice_function, self.bin_function = choice_function, bin_function
        self.in_features, self.embed_dim = in_features, embed_dim
        self.threshold_init_beta, self.threshold_init_cutoff = (
            threshold_init_beta,
            threshold_init_cutoff,
        )

        # Response values for each leaf
        self.response = nn.Parameter(
            torch.zeros([num_trees, tree_dim, embed_dim, 2**depth]), requires_grad=True
        )
        initialize_response_(self.response)

        # Feature selection logits (choose J)
        self.feature_selection_logits = nn.Parameter(
            torch.zeros([num_trees, depth, in_features]), requires_grad=True
        )
        initialize_selection_logits_(self.feature_selection_logits)

        # Embedding selection logits (choose D within J)
        self.embedding_selection_logits = nn.Parameter(
            torch.randn([num_trees, depth, in_features, embed_dim])
        )

        # Thresholds & temperatures (random initialization)
        self.feature_thresholds = nn.Parameter(torch.randn([num_trees, depth]))
        self.log_temperatures = nn.Parameter(torch.randn([num_trees, depth]))

        # Binary code mappings
        with torch.no_grad():
            indices = torch.arange(2**self.depth)
            offsets = 2 ** torch.arange(self.depth)
            bin_codes = (indices.view(1, -1) // offsets.view(-1, 1) % 2).to(
                torch.float32
            )
            bin_codes_1hot = torch.stack([bin_codes, 1.0 - bin_codes], dim=-1)
            self.bin_codes_1hot = nn.Parameter(bin_codes_1hot, requires_grad=False)
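            # E.g. for depth=2 the construction above yields
            #   bin_codes = [[0., 1., 0., 1.],
            #                [0., 0., 1., 1.]]
            # so column c is the binary code of leaf c (least significant bit
            # first); bin_codes_1hot stores each bit b as the pair (b, 1 - b)
            # for matching against the two-sided decisions in forward().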

    def initialize(self, x, eps=1e-6):
        """Data-aware initialization of thresholds and log-temperatures based on input data.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [batch_size, in_features, embed_dim] used for threshold initialization.
        eps : float, optional
            Small value added to avoid log(0) errors in temperature initialization. Default is 1e-6.
        """
        if len(x.shape) != 3:
            raise ValueError("Input tensor must have shape (batch_size, J, D)")

        if x.shape[0] < 1000:
            warn(
                "Data-aware initialization is performed on fewer than 1000 data points. "
                "This may cause instability. To avoid potential problems, run this model "
                "on a data batch with at least 1000 samples before training. "
                "Use with torch.no_grad() for memory efficiency."
            )

        with torch.no_grad():
            # Select features (J)
            feature_selectors = self.choice_function(
                self.feature_selection_logits, dim=-1
            )
            # feature_selectors shape: (num_trees, depth, J)

            selected_features = torch.einsum("bjd,ntj->bntd", x, feature_selectors)
            # selected_features shape: (B, num_trees, depth, D)

            # Select embeddings (D)
            embedding_selectors = self.choice_function(
                self.embedding_selection_logits, dim=-1
            )
            # embedding_selectors shape: (num_trees, depth, J, D)

            selected_embeddings = torch.einsum(
                "bntd,ntjd->bntd", selected_features, embedding_selectors
            )
            # selected_embeddings shape: (B, num_trees, depth, D)

            # Initialize thresholds using percentiles from the data
            percentiles_q = 100 * np.random.beta(
                self.threshold_init_beta,
                self.threshold_init_beta,
                size=[self.num_trees, self.depth],
            )
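            # The Beta(beta, beta) draw above is uniform on [0, 100] when
            # threshold_init_beta = 1, so each (tree, depth) threshold starts at
            # a uniformly random data percentile; larger values of
            # threshold_init_beta concentrate the draws around the median.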

            reshaped_embeddings = selected_embeddings.permute(1, 2, 0, 3).reshape(
                self.num_trees * self.depth, -1
            )
            self.feature_thresholds.data[...] = torch.as_tensor(
                list(
                    map(
                        np.percentile,
                        check_numpy(reshaped_embeddings),  # now correctly 2D
                        percentiles_q.flatten(),
                    )
                ),
                dtype=selected_embeddings.dtype,
                device=selected_embeddings.device,
            ).view(self.num_trees, self.depth)

            # Initialize temperatures based on the threshold differences
            temperatures = np.percentile(
                check_numpy(
                    abs(selected_embeddings - self.feature_thresholds.unsqueeze(-1))
                ),
                q=100 * min(1.0, self.threshold_init_cutoff),
                axis=0,
            )

            # Scale temperatures based on the cutoff
            temperatures /= max(1.0, self.threshold_init_cutoff)

            self.log_temperatures.data[...] = torch.log(
                torch.as_tensor(
                    temperatures.mean(-1),
                    dtype=selected_embeddings.dtype,
                    device=selected_embeddings.device,
                )
                + eps
            )

    def forward(self, x):
        if len(x.shape) != 3:
            raise ValueError("Input tensor must have shape (batch_size, J, D)")

        # Select feature (J) and embedding dimension (D) separately
        feature_selectors = self.choice_function(
            self.feature_selection_logits, dim=-1
        )  # [num_trees, depth, J]

        embedding_selectors = self.choice_function(
            self.embedding_selection_logits, dim=-1
        )  # [num_trees, depth, J, D]

        # Select features (J) first
        selected_features = torch.einsum("bjd,ntj->bntd", x, feature_selectors)

        # Select embeddings (D) within selected features
        selected_embeddings = torch.einsum(
            "bntd,ntjd->bntd", selected_features, embedding_selectors
        )

        # Compute threshold logits
        threshold_logits = (
            selected_embeddings - self.feature_thresholds.unsqueeze(0).unsqueeze(-1)
        ) * torch.exp(-self.log_temperatures.unsqueeze(0).unsqueeze(-1))

        threshold_logits = torch.stack([-threshold_logits, threshold_logits], dim=-1)

        # Compute binary decisions
        bins = self.bin_function(threshold_logits)

        bin_matches = torch.einsum("bntds,tcs->bntdc", bins, self.bin_codes_1hot)

        response_weights = torch.prod(bin_matches, dim=2)
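        # response_weights[b, n, d, c]: soft probability that sample b reaches
        # leaf c of tree n along embedding dim d (product over depth of the
        # agreement between each decision and leaf c's binary code).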

        # Compute final response
        response = torch.einsum("bnds,ncds->bnd", response_weights, self.response)
        return response

    def __repr__(self):
        return f"{self.__class__.__name__}(in_features={self.in_features}, embed_dim={self.embed_dim}, num_trees={self.num_trees}, depth={self.depth}, tree_dim={self.tree_dim}, flatten_output={self.flatten_output})"


class DenseBlock(nn.Module):
    """DenseBlock that sequentially stacks `Module` layers (e.g., ODSTE)
    with feature- and embedding-aware splits.

    Parameters
    ----------
    input_dim : int
        Number of features (J) in the input.
    embed_dim : int
        Embedding dimension per feature (D).
    layer_dim : int
        Number of trees (`num_trees`) in each ODSTE layer.
    num_layers : int
        Number of layers to stack in the block.
    tree_dim : int, optional
        Number of output channels from each tree. Default is 1.
    max_features : int, optional
        Maximum number of features for expansion. Default is None.
    input_dropout : float, optional
        Dropout rate applied to inputs during training. Default is 0.0.
    flatten_output : bool, optional
        If True, flattens the output along the tree dimension. Default is True.
    Module : nn.Module, optional
        Module class to use for each layer in the block. Default is `ODSTE`.
    **kwargs : dict
        Additional keyword arguments for `Module` instances.
    """

    def __init__(
        self,
        input_dim,
        embed_dim,
        layer_dim,
        num_layers,
        tree_dim=1,
        max_features=None,
        input_dropout=0.0,
        flatten_output=True,
        Module=ODSTE,
        **kwargs,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.layer_dim = layer_dim
        self.tree_dim = tree_dim
        self.max_features = max_features
        self.input_dropout = input_dropout
        self.flatten_output = flatten_output

        self.attention_layers = nn.ModuleList()
        self.odste_layers = nn.ModuleList()

        for _ in range(num_layers):
            # self.attention_layers.append(
            #     nn.MultiheadAttention(
            #         embed_dim=embed_dim, num_heads=1, batch_first=True
            #     )
            # )
            self.odste_layers.append(
                Module(
                    in_features=input_dim,
                    embed_dim=embed_dim,
                    num_trees=layer_dim,
                    tree_dim=tree_dim,
                    flatten_output=True,
                    **kwargs,
                )
            )
            # Each layer appends layer_dim * tree_dim new features, capped at max_features
            input_dim = min(
                input_dim + layer_dim * tree_dim, max_features or float("inf")
            )

    def forward(self, x):
        """Forward pass through the DenseBlock.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape [batch_size, J, D].

        Returns
        -------
        torch.Tensor
            Output tensor with expanded features.
        """
        initial_features = x.shape[1]  # J (num features)

        for odste_layer in self.odste_layers:
            # x, _ = attn_layer(x, x, x)  # Apply attention

            if self.max_features is not None:
                tail_features = min(self.max_features, x.shape[1]) - initial_features
                if tail_features > 0:
                    x = torch.cat(
                        [x[:, :initial_features, :], x[:, -tail_features:, :]], dim=1
                    )

            if self.training and self.input_dropout:
                x = F.dropout(x, self.input_dropout)

            h = odste_layer(x)  # Apply ODSTE layer
            x = torch.cat([x, h], dim=1)  # Concatenate new features

        return x
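
For orientation, a minimal usage sketch of the new block (shapes follow the docstrings above; the parameter values, and the assumption that ModuleWithInit triggers initialize() on the first forward batch as in the original NODE code, are illustrative rather than part of this diff):

import torch
from mambular.arch_utils.enode_utils import DenseBlock

# Illustrative batch: B=1200 samples (>= 1000 avoids the init warning),
# J=8 features, D=16 embedding dimensions per feature.
x = torch.randn(1200, 8, 16)

block = DenseBlock(
    input_dim=8,   # J
    embed_dim=16,  # D
    layer_dim=4,   # num_trees per ODSTE layer
    num_layers=2,
    tree_dim=1,
)

# Each layer concatenates its tree outputs onto the feature axis,
# growing J by layer_dim * tree_dim per layer: 8 -> 12 -> 16.
out = block(x)
print(out.shape)  # torch.Size([1200, 16, 16])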

mambular/base_models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,8 +12,10 @@
 from .tabularnn import TabulaRNN
 from .autoint import AutoInt
 from .trompt import Trompt
+from .enode import ENODE
 
 __all__ = [
+    "ENODE",
     "Trompt",
     "AutoInt",
     "MLP",
