Skip to content

Commit febf165

Browse files
committed
include unit tests
1 parent e60dd80 commit febf165

7 files changed

Lines changed: 377 additions & 711 deletions

File tree

tests/test_base.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import pytest
2+
import inspect
3+
import torch
4+
import os
5+
import importlib
6+
from mambular.base_models.basemodel import BaseModel
7+
8+
# Paths for models and configs
9+
MODEL_MODULE_PATH = "mambular.base_models"
10+
CONFIG_MODULE_PATH = "mambular.configs"
11+
12+
# Discover all models
13+
model_classes = []
14+
for filename in os.listdir(os.path.dirname(__file__) + "/../mambular/base_models"):
15+
if filename.endswith(".py") and filename not in [
16+
"__init__.py",
17+
"basemodel.py",
18+
"lightning_wrapper.py",
19+
"bayesian_tabm.py",
20+
]:
21+
module_name = f"{MODEL_MODULE_PATH}.{filename[:-3]}"
22+
module = importlib.import_module(module_name)
23+
24+
for name, obj in inspect.getmembers(module, inspect.isclass):
25+
if issubclass(obj, BaseModel) and obj is not BaseModel:
26+
model_classes.append(obj)
27+
28+
29+
def get_model_config(model_class):
30+
"""Dynamically load the correct config class for each model."""
31+
model_name = model_class.__name__ # e.g., "Mambular"
32+
config_class_name = f"Default{model_name}Config" # e.g., "DefaultMambularConfig"
33+
34+
try:
35+
config_module = importlib.import_module(
36+
f"{CONFIG_MODULE_PATH}.{model_name.lower()}_config"
37+
)
38+
config_class = getattr(config_module, config_class_name)
39+
return config_class() # Instantiate config
40+
except (ModuleNotFoundError, AttributeError) as e:
41+
pytest.fail(
42+
f"Could not find or instantiate config {config_class_name} for {model_name}: {e}"
43+
)
44+
45+
46+
@pytest.mark.parametrize("model_class", model_classes)
47+
def test_model_inherits_base_model(model_class):
48+
"""Test that each model correctly inherits from BaseModel."""
49+
assert issubclass(
50+
model_class, BaseModel
51+
), f"{model_class.__name__} should inherit from BaseModel."
52+
53+
54+
@pytest.mark.parametrize("model_class", model_classes)
55+
def test_model_has_forward_method(model_class):
56+
"""Test that each model has a forward method with *data."""
57+
assert hasattr(
58+
model_class, "forward"
59+
), f"{model_class.__name__} is missing a forward method."
60+
61+
sig = inspect.signature(model_class.forward)
62+
assert any(
63+
p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()
64+
), f"{model_class.__name__}.forward should have *data argument."
65+
66+
67+
@pytest.mark.parametrize("model_class", model_classes)
68+
def test_model_takes_config(model_class):
69+
"""Test that each model accepts a config argument."""
70+
sig = inspect.signature(model_class.__init__)
71+
assert (
72+
"config" in sig.parameters
73+
), f"{model_class.__name__} should accept a 'config' parameter."
74+
75+
76+
@pytest.mark.parametrize("model_class", model_classes)
77+
def test_model_has_num_classes(model_class):
78+
"""Test that each model accepts a num_classes argument."""
79+
sig = inspect.signature(model_class.__init__)
80+
assert (
81+
"num_classes" in sig.parameters
82+
), f"{model_class.__name__} should accept a 'num_classes' parameter."
83+
84+
85+
@pytest.mark.parametrize("model_class", model_classes)
86+
def test_model_calls_super_init(model_class):
87+
"""Test that each model calls super().__init__(config=config, **kwargs)."""
88+
source = inspect.getsource(model_class.__init__)
89+
assert (
90+
"super().__init__(config=config" in source
91+
), f"{model_class.__name__} should call super().__init__(config=config, **kwargs)."
92+
93+
94+
@pytest.mark.parametrize("model_class", model_classes)
95+
def test_model_initialization(model_class):
96+
"""Test that each model can be initialized with its correct config."""
97+
config = get_model_config(model_class)
98+
feature_info = (
99+
{
100+
"A": {
101+
"preprocessing": "imputer -> check_positive -> box-cox",
102+
"dimension": 1,
103+
"categories": None,
104+
}
105+
},
106+
{
107+
"sibsp": {
108+
"preprocessing": "imputer -> continuous_ordinal",
109+
"dimension": 1,
110+
"categories": 8,
111+
}
112+
},
113+
{},
114+
) # Mock feature info
115+
116+
try:
117+
model = model_class(
118+
feature_information=feature_info, num_classes=3, config=config
119+
)
120+
except Exception as e:
121+
pytest.fail(f"Failed to initialize {model_class.__name__}: {e}")
122+
123+
124+
@pytest.mark.parametrize("model_class", model_classes)
125+
def test_model_defines_key_attributes(model_class):
126+
"""Test that each model defines expected attributes like returns_ensemble"""
127+
config = get_model_config(model_class)
128+
feature_info = (
129+
{
130+
"A": {
131+
"preprocessing": "imputer -> check_positive -> box-cox",
132+
"dimension": 1,
133+
"categories": None,
134+
}
135+
},
136+
{
137+
"sibsp": {
138+
"preprocessing": "imputer -> continuous_ordinal",
139+
"dimension": 1,
140+
"categories": 8,
141+
}
142+
},
143+
{},
144+
) # Mock feature info
145+
146+
try:
147+
model = model_class(
148+
feature_information=feature_info, num_classes=3, config=config
149+
)
150+
except TypeError as e:
151+
pytest.fail(f"Failed to initialize {model_class.__name__}: {e}")
152+
153+
expected_attrs = ["returns_ensemble"]
154+
for attr in expected_attrs:
155+
assert hasattr(model, attr), f"{model_class.__name__} should define '{attr}'."

tests/test_classifier.py

Lines changed: 0 additions & 115 deletions
This file was deleted.

tests/test_configs.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import pytest
2+
import inspect
3+
import importlib
4+
import os
5+
import dataclasses
6+
import typing
7+
from mambular.configs.base_config import BaseConfig # Ensure correct path
8+
9+
CONFIG_MODULE_PATH = "mambular.configs"
10+
config_classes = []
11+
12+
# Discover all config classes in mambular/configs/
13+
for filename in os.listdir(os.path.dirname(__file__) + "/../mambular/configs"):
14+
if (
15+
filename.endswith(".py")
16+
and filename != "base_config.py"
17+
and not filename.startswith("__")
18+
):
19+
module_name = f"{CONFIG_MODULE_PATH}.{filename[:-3]}"
20+
module = importlib.import_module(module_name)
21+
22+
for name, obj in inspect.getmembers(module, inspect.isclass):
23+
if issubclass(obj, BaseConfig) and obj is not BaseConfig:
24+
config_classes.append(obj)
25+
26+
27+
@pytest.mark.parametrize("config_class", config_classes)
28+
def test_config_inherits_baseconfig(config_class):
29+
"""Test that each config class correctly inherits from BaseConfig."""
30+
assert issubclass(
31+
config_class, BaseConfig
32+
), f"{config_class.__name__} should inherit from BaseConfig."
33+
34+
35+
@pytest.mark.parametrize("config_class", config_classes)
36+
def test_config_instantiation(config_class):
37+
"""Test that each config class can be instantiated without errors."""
38+
try:
39+
config = config_class()
40+
except Exception as e:
41+
pytest.fail(f"Failed to instantiate {config_class.__name__}: {e}")
42+
43+
44+
@pytest.mark.parametrize("config_class", config_classes)
45+
def test_config_has_expected_attributes(config_class):
46+
"""Test that each config has all required attributes from BaseConfig."""
47+
base_attrs = {field.name for field in dataclasses.fields(BaseConfig)}
48+
config_attrs = {field.name for field in dataclasses.fields(config_class)}
49+
50+
missing_attrs = base_attrs - config_attrs
51+
assert (
52+
not missing_attrs
53+
), f"{config_class.__name__} is missing attributes: {missing_attrs}"
54+
55+
56+
@pytest.mark.parametrize("config_class", config_classes)
57+
def test_config_default_values(config_class):
58+
"""Ensure that each config class has default values assigned correctly."""
59+
config = config_class()
60+
61+
for field in dataclasses.fields(config_class):
62+
attr = field.name
63+
expected_type = field.type
64+
65+
assert hasattr(
66+
config, attr
67+
), f"{config_class.__name__} is missing attribute '{attr}'."
68+
69+
value = getattr(config, attr)
70+
71+
# Handle generic types properly
72+
origin = typing.get_origin(expected_type)
73+
74+
if origin is typing.Literal:
75+
# If the field is a Literal, ensure the value is one of the allowed options
76+
allowed_values = typing.get_args(expected_type)
77+
assert (
78+
value in allowed_values
79+
), f"{config_class.__name__}.{attr} has incorrect value: expected one of {allowed_values}, got {value}"
80+
elif origin is typing.Union:
81+
# For Union types (e.g., Optional[str]), check if value matches any type in the union
82+
allowed_types = typing.get_args(expected_type)
83+
assert any(
84+
isinstance(value, t) for t in allowed_types
85+
), f"{config_class.__name__}.{attr} has incorrect type: expected one of {allowed_types}, got {type(value)}"
86+
elif origin is not None:
87+
# If it's another generic type (e.g., list[str]), check against the base type
88+
assert (
89+
isinstance(value, origin) or value is None
90+
), f"{config_class.__name__}.{attr} has incorrect type: expected {expected_type}, got {type(value)}"
91+
else:
92+
# Standard type check
93+
assert (
94+
isinstance(value, expected_type) or value is None
95+
), f"{config_class.__name__}.{attr} has incorrect type: expected {expected_type}, got {type(value)}"
96+
97+
98+
@pytest.mark.parametrize("config_class", config_classes)
99+
def test_config_allows_updates(config_class):
100+
"""Ensure that config values can be updated and remain type-consistent."""
101+
config = config_class()
102+
103+
update_values = {
104+
"lr": 0.01,
105+
"d_model": 128,
106+
"embedding_type": "plr",
107+
"activation": lambda x: x, # Function update
108+
}
109+
110+
for attr, new_value in update_values.items():
111+
if hasattr(config, attr):
112+
setattr(config, attr, new_value)
113+
assert (
114+
getattr(config, attr) == new_value
115+
), f"{config_class.__name__}.{attr} did not update correctly."

0 commit comments

Comments
 (0)