Skip to content

Commit 9acb5b1

Browse files
committed
adapt imports
1 parent f085907 commit 9acb5b1

16 files changed

Lines changed: 81 additions & 2453 deletions

mambular/models/__init__.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
1-
from .fttransformer import FTTransformerClassifier, FTTransformerLSS, FTTransformerRegressor
1+
from .fttransformer import (
2+
FTTransformerClassifier,
3+
FTTransformerLSS,
4+
FTTransformerRegressor,
5+
)
26
from .mambatab import MambaTabClassifier, MambaTabLSS, MambaTabRegressor
3-
from .mambattention import MambAttentionClassifier, MambAttentionLSS, MambAttentionRegressor
7+
from .mambattention import (
8+
MambAttentionClassifier,
9+
MambAttentionLSS,
10+
MambAttentionRegressor,
11+
)
412
from .mambular import MambularClassifier, MambularLSS, MambularRegressor
513
from .mlp import MLPLSS, MLPClassifier, MLPRegressor
614
from .ndtf import NDTFLSS, NDTFClassifier, NDTFRegressor
715
from .node import NODELSS, NODEClassifier, NODERegressor
816
from .resnet import ResNetClassifier, ResNetLSS, ResNetRegressor
917
from .saint import SAINTLSS, SAINTClassifier, SAINTRegressor
10-
from .sklearn_base_classifier import SklearnBaseClassifier
11-
from .sklearn_base_lss import SklearnBaseLSS
12-
from .sklearn_base_regressor import SklearnBaseRegressor
18+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
19+
from .utils.sklearn_base_lss import SklearnBaseLSS
20+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
1321
from .tabm import TabMClassifier, TabMLSS, TabMRegressor
14-
from .tabtransformer import TabTransformerClassifier, TabTransformerLSS, TabTransformerRegressor
22+
from .tabtransformer import (
23+
TabTransformerClassifier,
24+
TabTransformerLSS,
25+
TabTransformerRegressor,
26+
)
1527
from .tabularnn import TabulaRNNClassifier, TabulaRNNLSS, TabulaRNNRegressor
1628

1729
__all__ = [

mambular/models/fttransformer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.ft_transformer import FTTransformer
22
from ..configs.fttransformer_config import DefaultFTTransformerConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class FTTransformerRegressor(SklearnBaseRegressor):
@@ -24,7 +24,9 @@ class and uses the FTTransformer model with the default FTTransformer
2424
)
2525

2626
def __init__(self, **kwargs):
27-
super().__init__(model=FTTransformer, config=DefaultFTTransformerConfig, **kwargs)
27+
super().__init__(
28+
model=FTTransformer, config=DefaultFTTransformerConfig, **kwargs
29+
)
2830

2931

3032
class FTTransformerClassifier(SklearnBaseClassifier):
@@ -42,7 +44,9 @@ class FTTransformerClassifier(SklearnBaseClassifier):
4244
)
4345

4446
def __init__(self, **kwargs):
45-
super().__init__(model=FTTransformer, config=DefaultFTTransformerConfig, **kwargs)
47+
super().__init__(
48+
model=FTTransformer, config=DefaultFTTransformerConfig, **kwargs
49+
)
4650

4751

4852
class FTTransformerLSS(SklearnBaseLSS):
@@ -61,4 +65,6 @@ class FTTransformerLSS(SklearnBaseLSS):
6165
)
6266

6367
def __init__(self, **kwargs):
64-
super().__init__(model=FTTransformer, config=DefaultFTTransformerConfig, **kwargs)
68+
super().__init__(
69+
model=FTTransformer, config=DefaultFTTransformerConfig, **kwargs
70+
)

mambular/models/mambatab.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.mambatab import MambaTab
22
from ..configs.mambatab_config import DefaultMambaTabConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class MambaTabRegressor(SklearnBaseRegressor):

mambular/models/mambattention.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.mambattn import MambAttention
22
from ..configs.mambattention_config import DefaultMambAttentionConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class MambAttentionRegressor(SklearnBaseRegressor):
@@ -23,7 +23,9 @@ class MambAttentionRegressor(SklearnBaseRegressor):
2323
)
2424

2525
def __init__(self, **kwargs):
26-
super().__init__(model=MambAttention, config=DefaultMambAttentionConfig, **kwargs)
26+
super().__init__(
27+
model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
28+
)
2729

2830

2931
class MambAttentionClassifier(SklearnBaseClassifier):
@@ -43,7 +45,9 @@ class MambAttentionClassifier(SklearnBaseClassifier):
4345
)
4446

4547
def __init__(self, **kwargs):
46-
super().__init__(model=MambAttention, config=DefaultMambAttentionConfig, **kwargs)
48+
super().__init__(
49+
model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
50+
)
4751

4852

4953
class MambAttentionLSS(SklearnBaseLSS):
@@ -63,4 +67,6 @@ class MambAttentionLSS(SklearnBaseLSS):
6367
)
6468

6569
def __init__(self, **kwargs):
66-
super().__init__(model=MambAttention, config=DefaultMambAttentionConfig, **kwargs)
70+
super().__init__(
71+
model=MambAttention, config=DefaultMambAttentionConfig, **kwargs
72+
)

mambular/models/mambular.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.mambular import Mambular
22
from ..configs.mambular_config import DefaultMambularConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class MambularRegressor(SklearnBaseRegressor):

mambular/models/mlp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.mlp import MLP
22
from ..configs.mlp_config import DefaultMLPConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class MLPRegressor(SklearnBaseRegressor):

mambular/models/ndtf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.ndtf import NDTF
22
from ..configs.ndtf_config import DefaultNDTFConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class NDTFRegressor(SklearnBaseRegressor):

mambular/models/node.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.node import NODE
22
from ..configs.node_config import DefaultNODEConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class NODERegressor(SklearnBaseRegressor):

mambular/models/resnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.resnet import ResNet
22
from ..configs.resnet_config import DefaultResNetConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class ResNetRegressor(SklearnBaseRegressor):

mambular/models/saint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from ..base_models.saint import SAINT
22
from ..configs.saint_config import DefaultSAINTConfig
33
from ..utils.docstring_generator import generate_docstring
4-
from .sklearn_base_classifier import SklearnBaseClassifier
5-
from .sklearn_base_lss import SklearnBaseLSS
6-
from .sklearn_base_regressor import SklearnBaseRegressor
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
77

88

99
class SAINTRegressor(SklearnBaseRegressor):

0 commit comments

Comments
 (0)