99from torch import Tensor
1010import math
1111
12-
13-
1412class TabR (BaseModel ):
13+ delu = None
14+ faiss = None
15+ faiss_torch_utils = None
16+
1517 def __init__ (
1618 self ,
1719 feature_information : tuple ,
@@ -22,6 +24,10 @@ def __init__(
2224 super ().__init__ (config = config , ** kwargs )
2325 self .save_hyperparameters (ignore = ["feature_information" ])
2426
27+ # lazy import
28+ if TabR .delu or TabR .faiss or TabR .faiss_torch_utils is None :
29+ self ._lazy_import_dependencies ()
30+
2531 self .returns_ensemble = False
2632 self .uses_candidates = True
2733
@@ -86,12 +92,8 @@ def make_block(prenorm: bool) -> nn.Sequential:
8692
8793 # Retrieval Module: R
8894 self .normalization = Normalization (d_main ) if mixer_normalization else None
89-
90- # lazy import
91- import delu
92- import faiss
93- import faiss .contrib .torch_utils
9495
96+ delu = TabR .delu
9597 self .label_encoder = (
9698 nn .Linear (1 , d_main )
9799 if num_classes == 1
@@ -134,6 +136,31 @@ def reset_parameters(self):
134136 assert isinstance (self .label_encoder [0 ], nn .Embedding )
135137 nn .init .uniform_ (self .label_encoder [0 ].weight , - 1.0 , 1.0 ) # type: ignore[code] # noqa: E501
136138
139+ def _lazy_import_dependencies (self ):
140+ """Lazily import external dependencies and store them as class attributes."""
141+ if TabR .delu is None :
142+ try :
143+ import delu
144+ TabR .delu = delu
145+ print ("Successfully lazy imported delu dependency." )
146+
147+ except ImportError :
148+ raise ImportError ("Failed to import delu module for TabR. Ensure all dependencies are installed\n "
149+ "You can install faiss running 'pip install delu'." ) from None
150+
151+ if TabR .faiss is None :
152+ try :
153+ import faiss
154+ import faiss .contrib .torch_utils
155+
156+ TabR .faiss = faiss
157+ TabR .faiss_torch_utils = faiss .contrib .torch_utils
158+ print ("Successfully lazy imported faiss dependency" )
159+
160+ except ImportError as e :
161+ raise ImportError ("Failed to import a required module for TabR. Ensure all dependencies are installed\n "
162+ "You can install delu by running 'pip install delu'." ) from None
163+
137164 def _encode (
138165 self ,
139166 a
@@ -210,7 +237,8 @@ def train_with_candidates(
210237 else torch .cat (
211238 [
212239 self ._encode (x )[1 ] # normalized x
213- for x in delu .iter_batches (
240+ # for x in delu.iter_batches(
241+ for x in TabR .delu .iter_batches (
214242 candidate_x ,
215243 self .candidate_encoding_batch_size
216244 )
@@ -229,12 +257,12 @@ def train_with_candidates(
229257 # initializing the search index
230258 if self .search_index is None :
231259 self .search_index = (
232- faiss .GpuIndexFlatL2 (
233- faiss .StandardGpuResources (),
260+ TabR . faiss .GpuIndexFlatL2 (
261+ TabR . faiss .StandardGpuResources (),
234262 d_main
235263 )
236264 if device .type == 'cuda'
237- else faiss .IndexFlatL2 (d_main )
265+ else TabR . faiss .IndexFlatL2 (d_main )
238266 )
239267 # Updating the index is much faster than creating a new one.
240268 self .search_index .reset ()
@@ -318,7 +346,7 @@ def validate_with_candidates(
318346 else torch .cat (
319347 [
320348 self ._encode (x )[1 ] # normalized x
321- for x in delu .iter_batches (
349+ for x in TabR . delu .iter_batches (
322350 candidate_x ,
323351 self .candidate_encoding_batch_size
324352 )
@@ -333,19 +361,15 @@ def validate_with_candidates(
333361 device = k .device
334362 context_size = self .context_size
335363
336- # lazy import
337- import faiss
338- import faiss .contrib .torch_utils
339364 if self .search_index is None :
340365 self .search_index = (
341- faiss .GpuIndexFlatL2 (faiss .StandardGpuResources (), d_main )
366+ TabR . faiss .GpuIndexFlatL2 (TabR . faiss .StandardGpuResources (), d_main )
342367 if device .type == 'cuda'
343- else faiss .IndexFlatL2 (d_main )
368+ else TabR . faiss .IndexFlatL2 (d_main )
344369 )
345370
346371 # Updating the index is much faster than creating a new one.
347372 self .search_index .reset ()
348- # print(candidate_k)
349373 self .search_index .add (candidate_k .to (torch .float32 )) # type: ignore[code]
350374 distances : Tensor
351375 context_idx : Tensor
@@ -407,7 +431,7 @@ def predict_with_candidates(
407431 else torch .cat (
408432 [
409433 self ._encode (x )[1 ] # normalized x
410- for x in delu .iter_batches (
434+ for x in TabR . delu .iter_batches (
411435 candidate_x ,
412436 self .candidate_encoding_batch_size
413437 )
@@ -422,19 +446,16 @@ def predict_with_candidates(
422446 device = k .device
423447 context_size = self .context_size
424448
425- # lazy import
426- import faiss
427- import faiss .contrib .torch_utils
428449 if self .search_index is None :
429450 self .search_index = (
430- faiss .GpuIndexFlatL2 (faiss .StandardGpuResources (), d_main )
451+ TabR . faiss .GpuIndexFlatL2 (TabR . faiss .StandardGpuResources (), d_main )
431452 if device .type == 'cuda'
432- else faiss .IndexFlatL2 (d_main )
453+ else TabR . faiss .IndexFlatL2 (d_main )
433454 )
455+
434456
435457 # Updating the index is much faster than creating a new one.
436458 self .search_index .reset ()
437- # print(candidate_k)
438459 self .search_index .add (candidate_k .to (torch .float32 )) # type: ignore[code]
439460 distances : Tensor
440461 context_idx : Tensor
0 commit comments