11# License: BSD 3-Clause
22from __future__ import annotations
33
4- import inspect
4+ from collections import OrderedDict
55
6+ import inspect
7+ import numpy as np
68import pytest
7-
9+ from unittest . mock import patch
810import openml .testing
9- from openml .extensions import get_extension_by_flow , get_extension_by_model , register_extension
11+ from openml .extensions import Extension , get_extension_by_flow , get_extension_by_model , register_extension
1012
1113
1214class DummyFlow :
@@ -40,54 +42,197 @@ def can_handle_model(model):
4042 return False
4143
4244
43- def _unregister ():
44- # "Un-register" the test extensions
45- while True :
46- rem_dum_ext1 = False
47- rem_dum_ext2 = False
48- try :
49- openml .extensions .extensions .remove (DummyExtension1 )
50- rem_dum_ext1 = True
51- except ValueError :
52- pass
53- try :
54- openml .extensions .extensions .remove (DummyExtension2 )
55- rem_dum_ext2 = True
56- except ValueError :
57- pass
58- if not rem_dum_ext1 and not rem_dum_ext2 :
59- break
45+ class DummyExtension (Extension ):
46+ @classmethod
47+ def can_handle_flow (cls , flow ):
48+ return isinstance (flow , DummyFlow )
49+
50+ @classmethod
51+ def can_handle_model (cls , model ):
52+ return isinstance (model , DummyModel )
53+
54+ def flow_to_model (
55+ self ,
56+ flow ,
57+ initialize_with_defaults = False ,
58+ strict_version = True ,
59+ ):
60+ if not isinstance (flow , DummyFlow ):
61+ raise ValueError ("Invalid flow" )
62+
63+ model = DummyModel ()
64+ model .defaults = initialize_with_defaults
65+ model .strict_version = strict_version
66+ return model
67+
68+ def model_to_flow (self , model ):
69+ if not isinstance (model , DummyModel ):
70+ raise ValueError ("Invalid model" )
71+ return DummyFlow ()
72+
73+ def get_version_information (self ):
74+ return ["dummy==1.0" ]
75+
76+ def create_setup_string (self , model ):
77+ return "DummyModel()"
78+
79+ def is_estimator (self , model ):
80+ return isinstance (model , DummyModel )
81+
82+ def seed_model (self , model , seed ):
83+ model .seed = seed
84+ return model
85+
86+ def _run_model_on_fold (
87+ self ,
88+ model ,
89+ task ,
90+ X_train ,
91+ rep_no ,
92+ fold_no ,
93+ y_train = None ,
94+ X_test = None ,
95+ ):
96+ preds = np .zeros (len (X_train ))
97+ probs = None
98+ measures = OrderedDict ()
99+ trace = None
100+ return preds , probs , measures , trace
101+
102+ def obtain_parameter_values (self , flow , model = None ):
103+ return []
104+
105+ def check_if_model_fitted (self , model ):
106+ return False
107+
108+ def instantiate_model_from_hpo_class (self , model , trace_iteration ):
109+ return DummyModel ()
110+
60111
61112
62113class TestInit (openml .testing .TestBase ):
63- def setUp (self ):
64- super ().setUp ()
65- _unregister ()
66114
67115 def test_get_extension_by_flow (self ):
68- assert get_extension_by_flow (DummyFlow ()) is None
69- with pytest .raises (ValueError , match = "No extension registered which can handle flow:" ):
70- get_extension_by_flow (DummyFlow (), raise_if_no_extension = True )
71- register_extension (DummyExtension1 )
72- assert isinstance (get_extension_by_flow (DummyFlow ()), DummyExtension1 )
73- register_extension (DummyExtension2 )
74- assert isinstance (get_extension_by_flow (DummyFlow ()), DummyExtension1 )
75- register_extension (DummyExtension1 )
76- with pytest .raises (
77- ValueError , match = "Multiple extensions registered which can handle flow:"
78- ):
79- get_extension_by_flow (DummyFlow ())
116+ # We replace the global list with a new empty list [] ONLY for this block
117+ with patch ("openml.extensions.extensions" , []):
118+ assert get_extension_by_flow (DummyFlow ()) is None
119+
120+ with pytest .raises (ValueError , match = "No extension registered which can handle flow:" ):
121+ get_extension_by_flow (DummyFlow (), raise_if_no_extension = True )
122+
123+ register_extension (DummyExtension1 )
124+ assert isinstance (get_extension_by_flow (DummyFlow ()), DummyExtension1 )
125+
126+ register_extension (DummyExtension2 )
127+ assert isinstance (get_extension_by_flow (DummyFlow ()), DummyExtension1 )
128+
129+ register_extension (DummyExtension1 )
130+ with pytest .raises (
131+ ValueError , match = "Multiple extensions registered which can handle flow:"
132+ ):
133+ get_extension_by_flow (DummyFlow ())
80134
81135 def test_get_extension_by_model (self ):
82- assert get_extension_by_model (DummyModel ()) is None
83- with pytest .raises (ValueError , match = "No extension registered which can handle model:" ):
84- get_extension_by_model (DummyModel (), raise_if_no_extension = True )
85- register_extension (DummyExtension1 )
86- assert isinstance (get_extension_by_model (DummyModel ()), DummyExtension1 )
87- register_extension (DummyExtension2 )
88- assert isinstance (get_extension_by_model (DummyModel ()), DummyExtension1 )
89- register_extension (DummyExtension1 )
90- with pytest .raises (
91- ValueError , match = "Multiple extensions registered which can handle model:"
92- ):
93- get_extension_by_model (DummyModel ())
136+ # Again, we start with a fresh empty list automatically
137+ with patch ("openml.extensions.extensions" , []):
138+ assert get_extension_by_model (DummyModel ()) is None
139+
140+ with pytest .raises (ValueError , match = "No extension registered which can handle model:" ):
141+ get_extension_by_model (DummyModel (), raise_if_no_extension = True )
142+
143+ register_extension (DummyExtension1 )
144+ assert isinstance (get_extension_by_model (DummyModel ()), DummyExtension1 )
145+
146+ register_extension (DummyExtension2 )
147+ assert isinstance (get_extension_by_model (DummyModel ()), DummyExtension1 )
148+
149+ register_extension (DummyExtension1 )
150+ with pytest .raises (
151+ ValueError , match = "Multiple extensions registered which can handle model:"
152+ ):
153+ get_extension_by_model (DummyModel ())
154+
155+
156+ def test_flow_to_model_with_defaults ():
157+ """Test flow_to_model with initialize_with_defaults=True."""
158+ ext = DummyExtension ()
159+ flow = DummyFlow ()
160+
161+ model = ext .flow_to_model (flow , initialize_with_defaults = True )
162+
163+ assert isinstance (model , DummyModel )
164+ assert model .defaults is True
165+
166+ def test_flow_to_model_strict_version ():
167+ """Test flow_to_model with strict_version parameter."""
168+ ext = DummyExtension ()
169+ flow = DummyFlow ()
170+
171+ model_strict = ext .flow_to_model (flow , strict_version = True )
172+ model_non_strict = ext .flow_to_model (flow , strict_version = False )
173+
174+ assert isinstance (model_strict , DummyModel )
175+ assert model_strict .strict_version is True
176+
177+ assert isinstance (model_non_strict , DummyModel )
178+ assert model_non_strict .strict_version is False
179+
180+ def test_model_to_flow_conversion ():
181+ """Test converting a model back to flow representation."""
182+ ext = DummyExtension ()
183+ model = DummyModel ()
184+
185+ flow = ext .model_to_flow (model )
186+
187+ assert isinstance (flow , DummyFlow )
188+
189+
190+ def test_invalid_flow_raises_error ():
191+ """Test that invalid flow raises appropriate error."""
192+ class InvalidFlow :
193+ pass
194+
195+ ext = DummyExtension ()
196+ flow = InvalidFlow ()
197+
198+ with pytest .raises (ValueError , match = "Invalid flow" ):
199+ ext .flow_to_model (flow )
200+
201+
202+ @patch ("openml.extensions.extensions" , [])
203+ def test_extension_not_found_error_message ():
204+ """Test error message contains helpful information."""
205+ class UnknownModel :
206+ pass
207+
208+ with pytest .raises (ValueError , match = "No extension registered" ):
209+ get_extension_by_model (UnknownModel (), raise_if_no_extension = True )
210+
211+
212+ def test_register_same_extension_twice ():
213+ """Test behavior when registering same extension twice."""
214+ # Using a context manager here to isolate the list
215+ with patch ("openml.extensions.extensions" , []):
216+ register_extension (DummyExtension )
217+ register_extension (DummyExtension )
218+
219+ matches = [
220+ ext for ext in openml .extensions .extensions
221+ if ext is DummyExtension
222+ ]
223+ assert len (matches ) == 2
224+
225+
226+ @patch ("openml.extensions.extensions" , [])
227+ def test_extension_priority_order ():
228+ """Test that extensions are checked in registration order."""
229+ class DummyExtensionA (DummyExtension ):
230+ pass
231+ class DummyExtensionB (DummyExtension ):
232+ pass
233+
234+ register_extension (DummyExtensionA )
235+ register_extension (DummyExtensionB )
236+
237+ assert openml .extensions .extensions [0 ] is DummyExtensionA
238+ assert openml .extensions .extensions [1 ] is DummyExtensionB
0 commit comments