@@ -85,6 +85,11 @@ def _update_cost_state(client, mf_cost, log=None):
8585 """
8686 from ax .core .base_trial import TrialStatus as _TS
8787 from collections import defaultdict
88+
89+ def _get_fidelity_feature_index (cl , fid_param_name ):
90+ """Return the feature index of the fidelity parameter in the search space."""
91+ param_names = list (cl ._experiment .search_space .parameters .keys ())
92+ return param_names .index (fid_param_name )
8893 cost_metric_name = mf_cost ["metric" ]
8994 fidelity_param_name = mf_cost ["fidelity_param" ]
9095 # Collect observed costs per fidelity level (keyed by raw fidelity value)
@@ -128,12 +133,15 @@ def _update_cost_state(client, mf_cost, log=None):
128133 fid_weight = (max (costs ) - min (costs )) / fid_range
129134 else :
130135 fid_weight = 1.0
131- # Write into acqf_opts — Ax passes these to the input constructor
136+ # Write into acqf_opts — Ax passes these to the input constructor.
137+ # fidelity_weights is keyed by feature index (same as target_fidelities),
138+ # NOT by fidelity values. Ax extracts the fidelity feature index from
139+ # search_space_digest.fidelity_features.
132140 opts = mf_cost ["acqf_opts_ref" ]
133141 opts ["cost_intercept" ] = max (cost_intercept , 1e-6 )
134- opts [ "fidelity_weights" ] = {
135- int ( fv ): fid_weight for fv in fid_vals
136- }
142+ # Determine the fidelity parameter's feature index
143+ fid_feature_idx = _get_fidelity_feature_index ( client , mf_cost [ "fidelity_param" ])
144+ opts [ "fidelity_weights" ] = { fid_feature_idx : fid_weight }
137145 if log :
138146 log .info ("Multi-fidelity cost updated: intercept=%.2f, weight=%.2f "
139147 "(from %d fidelity levels: %s)" , cost_intercept , fid_weight ,
@@ -273,12 +281,15 @@ def _patched_object_to_json(obj, **kwargs):
273281 _fidelity_cfg = getattr (exp_cfg , '_fidelity_params' , None ) or \
274282 getattr (type (exp_cfg ), '_fidelity_params' , None )
275283 if _fidelity_cfg :
284+ from ax .core .parameter import FixedParameter
276285 for pname , target_val in _fidelity_cfg .items ():
277286 p = client ._experiment .search_space .parameters .get (pname )
278- if p is not None :
287+ if p is not None and not isinstance ( p , FixedParameter ) :
279288 p ._is_fidelity = True
280289 p ._target_value = target_val
281290 log .info ("Fidelity parameter: %s (target_value=%s)" , pname , target_val )
291+ elif isinstance (p , FixedParameter ):
292+ log .info ("Fidelity parameter %s is fixed (specialized) — skipping MF wiring" , pname )
282293
283294 ## 2.0 Trial generation
284295 # Suppress Ax's verbose GenerationStrategy repr log
@@ -321,9 +332,16 @@ def _patched_object_to_json(obj, **kwargs):
321332 try :
322333 from botorch .models .gp_regression_fidelity import SingleTaskMultiFidelityGP
323334 from ax .generators .torch .botorch_modular .surrogate import SurrogateSpec
335+ from ax .generators .torch .botorch_modular .utils import ModelConfig
336+ from foambo .robustness import MFHVKGAcquisition
324337 gk ["surrogate_spec" ] = SurrogateSpec (
325- botorch_model_class = SingleTaskMultiFidelityGP ,
338+ model_configs = [ModelConfig (
339+ botorch_model_class = SingleTaskMultiFidelityGP ,
340+ )],
341+ allow_batched_models = False ,
326342 )
343+ # qMFHVKG input constructor doesn't accept X_pending
344+ spec .generator_kwargs ["acquisition_class" ] = MFHVKGAcquisition
327345 log .info ("Multi-fidelity: auto-selected SingleTaskMultiFidelityGP surrogate "
328346 "(fidelity param: %s)" , _fidelity_params [0 ].name )
329347 # Wire cost model for MF acquisition.
0 commit comments