@@ -87,7 +87,7 @@ def _update_cost_state(client, mf_cost, log=None):
8787 from collections import defaultdict
8888 cost_metric_name = mf_cost ["metric" ]
8989 fidelity_param_name = mf_cost ["fidelity_param" ]
90- # Collect observed costs per fidelity level
90+ # Collect observed costs per fidelity level (keyed by raw fidelity value)
9191 sums = defaultdict (float )
9292 counts = defaultdict (int )
9393 df = client ._experiment .lookup_data ().df
@@ -100,19 +100,30 @@ def _update_cost_state(client, mf_cost, log=None):
100100 sub = df [(df .trial_index == trial .index ) & (df .metric_name == cost_metric_name )]
101101 if sub .empty :
102102 continue
103- sums [float ( fid_val ) ] += float (sub ["mean" ].iloc [- 1 ])
104- counts [float ( fid_val ) ] += 1
103+ sums [fid_val ] += float (sub ["mean" ].iloc [- 1 ])
104+ counts [fid_val ] += 1
105105 per_f = {fv : sums [fv ] / counts [fv ] for fv in sums if counts [fv ] > 0 }
106106 if not per_f or per_f == mf_cost ["state" ]["per_fidelity" ]:
107107 return # no change
108108 mf_cost ["state" ]["per_fidelity" ] = per_f
109- # Derive AffineFidelityCostModel params
110- fid_vals = sorted (per_f .keys ())
111- costs = [per_f [fv ] for fv in fid_vals ]
109+
110+ # Build numeric mapping for AffineFidelityCostModel.
111+ # For numeric fidelity: use values directly.
112+ # For categorical (str) fidelity: map to ordinal indices sorted by cost.
113+ raw_keys = sorted (per_f .keys (), key = lambda k : per_f [k ])
114+ is_numeric = all (isinstance (k , (int , float )) for k in raw_keys )
115+ if is_numeric :
116+ fid_vals = sorted (float (k ) for k in raw_keys )
117+ costs = [per_f [k ] for k in sorted (per_f .keys ())]
118+ else :
119+ # Map categorical levels to 0..N-1 ordered by ascending cost
120+ fid_vals = list (range (len (raw_keys )))
121+ costs = [per_f [k ] for k in raw_keys ]
122+
112123 cost_intercept = min (costs ) # cheapest fidelity floor
113124 # Weights: slope per fidelity unit. For discrete {0,1}: w = cost_high - cost_low.
114125 # For continuous: linear fit.
115- fid_range = fid_vals [- 1 ] - fid_vals [0 ]
126+ fid_range = fid_vals [- 1 ] - fid_vals [0 ] if len ( fid_vals ) > 1 else 0
116127 if fid_range > 0 :
117128 fid_weight = (max (costs ) - min (costs )) / fid_range
118129 else :
@@ -121,11 +132,12 @@ def _update_cost_state(client, mf_cost, log=None):
121132 opts = mf_cost ["acqf_opts_ref" ]
122133 opts ["cost_intercept" ] = max (cost_intercept , 1e-6 )
123134 opts ["fidelity_weights" ] = {
124- int (k ): fid_weight for k in mf_cost [ "state" ][ "per_fidelity" ]
135+ int (fv ): fid_weight for fv in fid_vals
125136 }
126137 if log :
127138 log .info ("Multi-fidelity cost updated: intercept=%.2f, weight=%.2f "
128- "(from %d fidelity levels)" , cost_intercept , fid_weight , len (per_f ))
139+ "(from %d fidelity levels: %s)" , cost_intercept , fid_weight ,
140+ len (per_f ), list (per_f .keys ()))
129141
130142
131143def optimize (cfg , debug = False ):
0 commit comments