Skip to content

Commit 5b4d52c

Browse files
committed
include classifier, regressor and lss for Tangos
1 parent 20908d9 commit 5b4d52c

2 files changed

Lines changed: 70 additions & 0 deletions

File tree

mambular/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@
2828
from .autoint import AutoIntClassifier, AutoIntLSS, AutoIntRegressor
2929
from .trompt import TromptClassifier, TromptLSS, TromptRegressor
3030
from .enode import ENODEClassifier, ENODELSS, ENODERegressor
31+
from .tangos import TangosClassifier, TangosLSS, TangosRegressor
3132

3233
__all__ = [
34+
"TangosClassifier",
35+
"TangosLSS",
36+
"TangosRegressor",
3337
"ENODEClassifier",
3438
"ENODELSS",
3539
"ENODERegressor",

mambular/models/tangos.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from ..base_models.tangos import Tangos
2+
from ..configs.tangos_config import DefaultTangosConfig
3+
from ..utils.docstring_generator import generate_docstring
4+
from .utils.sklearn_base_classifier import SklearnBaseClassifier
5+
from .utils.sklearn_base_lss import SklearnBaseLSS
6+
from .utils.sklearn_base_regressor import SklearnBaseRegressor
7+
8+
9+
class TangosRegressor(SklearnBaseRegressor):
10+
__doc__ = generate_docstring(
11+
DefaultTangosConfig,
12+
model_description="""
13+
Tangos regressor. This class extends the SklearnBaseRegressor class and uses the Tangos model
14+
with the default Tangos configuration.
15+
""",
16+
examples="""
17+
>>> from mambular.models import TangosRegressor
18+
>>> model = TangosRegressor(d_model=64, n_layers=8)
19+
>>> model.fit(X_train, y_train)
20+
>>> preds = model.predict(X_test)
21+
>>> model.evaluate(X_test, y_test)
22+
""",
23+
)
24+
25+
def __init__(self, **kwargs):
26+
super().__init__(model=Tangos, config=DefaultTangosConfig, **kwargs)
27+
28+
29+
class TangosClassifier(SklearnBaseClassifier):
30+
__doc__ = generate_docstring(
31+
DefaultTangosConfig,
32+
model_description="""
33+
Tangos classifier This class extends the SklearnBaseClassifier class and uses the Tangos model
34+
with the default Tangos configuration.
35+
""",
36+
examples="""
37+
>>> from mambular.models import TangosClassifier
38+
>>> model = TangosClassifier(d_model=64, n_layers=8)
39+
>>> model.fit(X_train, y_train)
40+
>>> preds = model.predict(X_test)
41+
>>> model.evaluate(X_test, y_test)
42+
""",
43+
)
44+
45+
def __init__(self, **kwargs):
46+
super().__init__(model=Tangos, config=DefaultTangosConfig, **kwargs)
47+
48+
49+
class TangosLSS(SklearnBaseLSS):
50+
__doc__ = generate_docstring(
51+
DefaultTangosConfig,
52+
model_description="""
53+
Tangos for distributional regression. This class extends the SklearnBaseLSS class and uses the Tangos model
54+
with the default Tangos configuration.
55+
""",
56+
examples="""
57+
>>> from mambular.models import TangosLSS
58+
>>> model = TangosLSS(d_model=64, n_layers=8)
59+
>>> model.fit(X_train, y_train, family='normal')
60+
>>> preds = model.predict(X_test)
61+
>>> model.evaluate(X_test, y_test)
62+
""",
63+
)
64+
65+
def __init__(self, **kwargs):
66+
super().__init__(model=Tangos, config=DefaultTangosConfig, **kwargs)

0 commit comments

Comments
 (0)