Skip to content

Commit 74088a7

Browse files
committed
TabR lazy imports
1 parent 30f7cc6 commit 74088a7

2 files changed

Lines changed: 48 additions & 28 deletions

File tree

mambular/base_models/tabr.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
from torch import Tensor
1010
import math
1111

12-
13-
1412
class TabR(BaseModel):
13+
delu = None
14+
faiss = None
15+
faiss_torch_utils = None
16+
1517
def __init__(
1618
self,
1719
feature_information: tuple,
@@ -22,6 +24,10 @@ def __init__(
2224
super().__init__(config=config, **kwargs)
2325
self.save_hyperparameters(ignore=["feature_information"])
2426

27+
# lazy import
28+
if TabR.delu is None or TabR.faiss is None or TabR.faiss_torch_utils is None:
29+
self._lazy_import_dependencies()
30+
2531
self.returns_ensemble = False
2632
self.uses_candidates = True
2733

@@ -86,12 +92,8 @@ def make_block(prenorm: bool) -> nn.Sequential:
8692

8793
# Retrieval Module: R
8894
self.normalization = Normalization(d_main) if mixer_normalization else None
89-
90-
# lazy import
91-
import delu
92-
import faiss
93-
import faiss.contrib.torch_utils
9495

96+
delu = TabR.delu
9597
self.label_encoder = (
9698
nn.Linear(1, d_main)
9799
if num_classes == 1
@@ -134,6 +136,31 @@ def reset_parameters(self):
134136
assert isinstance(self.label_encoder[0], nn.Embedding)
135137
nn.init.uniform_(self.label_encoder[0].weight, -1.0, 1.0) # type: ignore[code] # noqa: E501
136138

139+
def _lazy_import_dependencies(self):
140+
"""Lazily import external dependencies and store them as class attributes."""
141+
if TabR.delu is None:
142+
try:
143+
import delu
144+
TabR.delu = delu
145+
print("Successfully lazy imported delu dependency.")
146+
147+
except ImportError:
148+
raise ImportError("Failed to import delu module for TabR. Ensure all dependencies are installed\n"
149+
"You can install delu by running 'pip install delu'.") from None
150+
151+
if TabR.faiss is None:
152+
try:
153+
import faiss
154+
import faiss.contrib.torch_utils
155+
156+
TabR.faiss = faiss
157+
TabR.faiss_torch_utils = faiss.contrib.torch_utils
158+
print("Successfully lazy imported faiss dependency")
159+
160+
except ImportError:
161+
raise ImportError("Failed to import faiss module for TabR. Ensure all dependencies are installed\n"
162+
"You can install faiss by running 'pip install faiss-cpu'.") from None
163+
137164
def _encode(
138165
self,
139166
a
@@ -210,7 +237,8 @@ def train_with_candidates(
210237
else torch.cat(
211238
[
212239
self._encode(x)[1] # normalized x
213-
for x in delu.iter_batches(
240+
# for x in delu.iter_batches(
241+
for x in TabR.delu.iter_batches(
214242
candidate_x,
215243
self.candidate_encoding_batch_size
216244
)
@@ -229,12 +257,12 @@ def train_with_candidates(
229257
# initializing the search index
230258
if self.search_index is None:
231259
self.search_index = (
232-
faiss.GpuIndexFlatL2(
233-
faiss.StandardGpuResources(),
260+
TabR.faiss.GpuIndexFlatL2(
261+
TabR.faiss.StandardGpuResources(),
234262
d_main
235263
)
236264
if device.type == 'cuda'
237-
else faiss.IndexFlatL2(d_main)
265+
else TabR.faiss.IndexFlatL2(d_main)
238266
)
239267
# Updating the index is much faster than creating a new one.
240268
self.search_index.reset()
@@ -318,7 +346,7 @@ def validate_with_candidates(
318346
else torch.cat(
319347
[
320348
self._encode(x)[1] # normalized x
321-
for x in delu.iter_batches(
349+
for x in TabR.delu.iter_batches(
322350
candidate_x,
323351
self.candidate_encoding_batch_size
324352
)
@@ -333,19 +361,15 @@ def validate_with_candidates(
333361
device = k.device
334362
context_size = self.context_size
335363

336-
# lazy import
337-
import faiss
338-
import faiss.contrib.torch_utils
339364
if self.search_index is None:
340365
self.search_index = (
341-
faiss.GpuIndexFlatL2(faiss.StandardGpuResources(), d_main)
366+
TabR.faiss.GpuIndexFlatL2(TabR.faiss.StandardGpuResources(), d_main)
342367
if device.type == 'cuda'
343-
else faiss.IndexFlatL2(d_main)
368+
else TabR.faiss.IndexFlatL2(d_main)
344369
)
345370

346371
# Updating the index is much faster than creating a new one.
347372
self.search_index.reset()
348-
# print(candidate_k)
349373
self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
350374
distances: Tensor
351375
context_idx: Tensor
@@ -407,7 +431,7 @@ def predict_with_candidates(
407431
else torch.cat(
408432
[
409433
self._encode(x)[1] # normalized x
410-
for x in delu.iter_batches(
434+
for x in TabR.delu.iter_batches(
411435
candidate_x,
412436
self.candidate_encoding_batch_size
413437
)
@@ -422,19 +446,16 @@ def predict_with_candidates(
422446
device = k.device
423447
context_size = self.context_size
424448

425-
# lazy import
426-
import faiss
427-
import faiss.contrib.torch_utils
428449
if self.search_index is None:
429450
self.search_index = (
430-
faiss.GpuIndexFlatL2(faiss.StandardGpuResources(), d_main)
451+
TabR.faiss.GpuIndexFlatL2(TabR.faiss.StandardGpuResources(), d_main)
431452
if device.type == 'cuda'
432-
else faiss.IndexFlatL2(d_main)
453+
else TabR.faiss.IndexFlatL2(d_main)
433454
)
455+
434456

435457
# Updating the index is much faster than creating a new one.
436458
self.search_index.reset()
437-
# print(candidate_k)
438459
self.search_index.add(candidate_k.to(torch.float32)) # type: ignore[code]
439460
distances: Tensor
440461
context_idx: Tensor

mambular/configs/tabr_config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ class DefaultTabRConfig(BaseConfig):
3232
context_size:int=96
3333

3434
# Embedding Parameters
35-
emebedding_type: str = "plr"
35+
embedding_type: str = "plr"
3636
plr_lite: bool = True
3737
n_frequencies: int = 75
38-
frequencies_init_scale: float = 0.045
39-
38+
frequencies_init_scale: float = 0.045

0 commit comments

Comments
 (0)