Skip to content

Commit 06ac6d0

Browse files
authored
[ENH] Extend Extension class test suite (#1560)
#### Metadata * Reference Issue: fixes #1545 * New Tests Added: Yes * Documentation Updated: No * Change Log Entry: Add tests for extension interface contract and extension registry edge cases #### Details * What does this PR implement/fix? Explain your changes. This PR adds unit tests for the OpenML Extension interface and for extension registry behavior. The tests added are the 7 tests mentioned in #1545 * Why is this change necessary? What is the problem it solves? Previously, only the non-abstract registry helpers (`get_extension_by_model`, `get_extension_by_flow`) were covered. The abstract `Extension` interface itself was not tested.
1 parent 0769ff5 commit 06ac6d0

1 file changed

Lines changed: 192 additions & 47 deletions

File tree

Lines changed: 192 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4-
import inspect
4+
from collections import OrderedDict
55

6+
import inspect
7+
import numpy as np
68
import pytest
7-
9+
from unittest.mock import patch
810
import openml.testing
9-
from openml.extensions import get_extension_by_flow, get_extension_by_model, register_extension
11+
from openml.extensions import Extension, get_extension_by_flow, get_extension_by_model, register_extension
1012

1113

1214
class DummyFlow:
@@ -40,54 +42,197 @@ def can_handle_model(model):
4042
return False
4143

4244

43-
def _unregister():
44-
# "Un-register" the test extensions
45-
while True:
46-
rem_dum_ext1 = False
47-
rem_dum_ext2 = False
48-
try:
49-
openml.extensions.extensions.remove(DummyExtension1)
50-
rem_dum_ext1 = True
51-
except ValueError:
52-
pass
53-
try:
54-
openml.extensions.extensions.remove(DummyExtension2)
55-
rem_dum_ext2 = True
56-
except ValueError:
57-
pass
58-
if not rem_dum_ext1 and not rem_dum_ext2:
59-
break
45+
class DummyExtension(Extension):
46+
@classmethod
47+
def can_handle_flow(cls, flow):
48+
return isinstance(flow, DummyFlow)
49+
50+
@classmethod
51+
def can_handle_model(cls, model):
52+
return isinstance(model, DummyModel)
53+
54+
def flow_to_model(
55+
self,
56+
flow,
57+
initialize_with_defaults=False,
58+
strict_version=True,
59+
):
60+
if not isinstance(flow, DummyFlow):
61+
raise ValueError("Invalid flow")
62+
63+
model = DummyModel()
64+
model.defaults = initialize_with_defaults
65+
model.strict_version = strict_version
66+
return model
67+
68+
def model_to_flow(self, model):
69+
if not isinstance(model, DummyModel):
70+
raise ValueError("Invalid model")
71+
return DummyFlow()
72+
73+
def get_version_information(self):
74+
return ["dummy==1.0"]
75+
76+
def create_setup_string(self, model):
77+
return "DummyModel()"
78+
79+
def is_estimator(self, model):
80+
return isinstance(model, DummyModel)
81+
82+
def seed_model(self, model, seed):
83+
model.seed = seed
84+
return model
85+
86+
def _run_model_on_fold(
87+
self,
88+
model,
89+
task,
90+
X_train,
91+
rep_no,
92+
fold_no,
93+
y_train=None,
94+
X_test=None,
95+
):
96+
preds = np.zeros(len(X_train))
97+
probs = None
98+
measures = OrderedDict()
99+
trace = None
100+
return preds, probs, measures, trace
101+
102+
def obtain_parameter_values(self, flow, model=None):
103+
return []
104+
105+
def check_if_model_fitted(self, model):
106+
return False
107+
108+
def instantiate_model_from_hpo_class(self, model, trace_iteration):
109+
return DummyModel()
110+
60111

61112

62113
class TestInit(openml.testing.TestBase):
63-
def setUp(self):
64-
super().setUp()
65-
_unregister()
66114

67115
def test_get_extension_by_flow(self):
68-
assert get_extension_by_flow(DummyFlow()) is None
69-
with pytest.raises(ValueError, match="No extension registered which can handle flow:"):
70-
get_extension_by_flow(DummyFlow(), raise_if_no_extension=True)
71-
register_extension(DummyExtension1)
72-
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
73-
register_extension(DummyExtension2)
74-
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
75-
register_extension(DummyExtension1)
76-
with pytest.raises(
77-
ValueError, match="Multiple extensions registered which can handle flow:"
78-
):
79-
get_extension_by_flow(DummyFlow())
116+
# We replace the global list with a new empty list [] ONLY for this block
117+
with patch("openml.extensions.extensions", []):
118+
assert get_extension_by_flow(DummyFlow()) is None
119+
120+
with pytest.raises(ValueError, match="No extension registered which can handle flow:"):
121+
get_extension_by_flow(DummyFlow(), raise_if_no_extension=True)
122+
123+
register_extension(DummyExtension1)
124+
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
125+
126+
register_extension(DummyExtension2)
127+
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
128+
129+
register_extension(DummyExtension1)
130+
with pytest.raises(
131+
ValueError, match="Multiple extensions registered which can handle flow:"
132+
):
133+
get_extension_by_flow(DummyFlow())
80134

81135
def test_get_extension_by_model(self):
82-
assert get_extension_by_model(DummyModel()) is None
83-
with pytest.raises(ValueError, match="No extension registered which can handle model:"):
84-
get_extension_by_model(DummyModel(), raise_if_no_extension=True)
85-
register_extension(DummyExtension1)
86-
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
87-
register_extension(DummyExtension2)
88-
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
89-
register_extension(DummyExtension1)
90-
with pytest.raises(
91-
ValueError, match="Multiple extensions registered which can handle model:"
92-
):
93-
get_extension_by_model(DummyModel())
136+
# Again, we start with a fresh empty list automatically
137+
with patch("openml.extensions.extensions", []):
138+
assert get_extension_by_model(DummyModel()) is None
139+
140+
with pytest.raises(ValueError, match="No extension registered which can handle model:"):
141+
get_extension_by_model(DummyModel(), raise_if_no_extension=True)
142+
143+
register_extension(DummyExtension1)
144+
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
145+
146+
register_extension(DummyExtension2)
147+
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
148+
149+
register_extension(DummyExtension1)
150+
with pytest.raises(
151+
ValueError, match="Multiple extensions registered which can handle model:"
152+
):
153+
get_extension_by_model(DummyModel())
154+
155+
156+
def test_flow_to_model_with_defaults():
157+
"""Test flow_to_model with initialize_with_defaults=True."""
158+
ext = DummyExtension()
159+
flow = DummyFlow()
160+
161+
model = ext.flow_to_model(flow, initialize_with_defaults=True)
162+
163+
assert isinstance(model, DummyModel)
164+
assert model.defaults is True
165+
166+
def test_flow_to_model_strict_version():
167+
"""Test flow_to_model with strict_version parameter."""
168+
ext = DummyExtension()
169+
flow = DummyFlow()
170+
171+
model_strict = ext.flow_to_model(flow, strict_version=True)
172+
model_non_strict = ext.flow_to_model(flow, strict_version=False)
173+
174+
assert isinstance(model_strict, DummyModel)
175+
assert model_strict.strict_version is True
176+
177+
assert isinstance(model_non_strict, DummyModel)
178+
assert model_non_strict.strict_version is False
179+
180+
def test_model_to_flow_conversion():
181+
"""Test converting a model back to flow representation."""
182+
ext = DummyExtension()
183+
model = DummyModel()
184+
185+
flow = ext.model_to_flow(model)
186+
187+
assert isinstance(flow, DummyFlow)
188+
189+
190+
def test_invalid_flow_raises_error():
191+
"""Test that invalid flow raises appropriate error."""
192+
class InvalidFlow:
193+
pass
194+
195+
ext = DummyExtension()
196+
flow = InvalidFlow()
197+
198+
with pytest.raises(ValueError, match="Invalid flow"):
199+
ext.flow_to_model(flow)
200+
201+
202+
@patch("openml.extensions.extensions", [])
203+
def test_extension_not_found_error_message():
204+
"""Test error message contains helpful information."""
205+
class UnknownModel:
206+
pass
207+
208+
with pytest.raises(ValueError, match="No extension registered"):
209+
get_extension_by_model(UnknownModel(), raise_if_no_extension=True)
210+
211+
212+
def test_register_same_extension_twice():
213+
"""Test behavior when registering same extension twice."""
214+
# Using a context manager here to isolate the list
215+
with patch("openml.extensions.extensions", []):
216+
register_extension(DummyExtension)
217+
register_extension(DummyExtension)
218+
219+
matches = [
220+
ext for ext in openml.extensions.extensions
221+
if ext is DummyExtension
222+
]
223+
assert len(matches) == 2
224+
225+
226+
@patch("openml.extensions.extensions", [])
227+
def test_extension_priority_order():
228+
"""Test that extensions are checked in registration order."""
229+
class DummyExtensionA(DummyExtension):
230+
pass
231+
class DummyExtensionB(DummyExtension):
232+
pass
233+
234+
register_extension(DummyExtensionA)
235+
register_extension(DummyExtensionB)
236+
237+
assert openml.extensions.extensions[0] is DummyExtensionA
238+
assert openml.extensions.extensions[1] is DummyExtensionB

0 commit comments

Comments
 (0)