Skip to content

Commit ab5d916

Browse files
author
Elwardi
committed
fix: multi fidelity on categorical param with continious cost
1 parent 3281a0f commit ab5d916

1 file changed

Lines changed: 21 additions & 9 deletions

File tree

src/foambo/optimize.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _update_cost_state(client, mf_cost, log=None):
8787
from collections import defaultdict
8888
cost_metric_name = mf_cost["metric"]
8989
fidelity_param_name = mf_cost["fidelity_param"]
90-
# Collect observed costs per fidelity level
90+
# Collect observed costs per fidelity level (keyed by raw fidelity value)
9191
sums = defaultdict(float)
9292
counts = defaultdict(int)
9393
df = client._experiment.lookup_data().df
@@ -100,19 +100,30 @@ def _update_cost_state(client, mf_cost, log=None):
100100
sub = df[(df.trial_index == trial.index) & (df.metric_name == cost_metric_name)]
101101
if sub.empty:
102102
continue
103-
sums[float(fid_val)] += float(sub["mean"].iloc[-1])
104-
counts[float(fid_val)] += 1
103+
sums[fid_val] += float(sub["mean"].iloc[-1])
104+
counts[fid_val] += 1
105105
per_f = {fv: sums[fv] / counts[fv] for fv in sums if counts[fv] > 0}
106106
if not per_f or per_f == mf_cost["state"]["per_fidelity"]:
107107
return # no change
108108
mf_cost["state"]["per_fidelity"] = per_f
109-
# Derive AffineFidelityCostModel params
110-
fid_vals = sorted(per_f.keys())
111-
costs = [per_f[fv] for fv in fid_vals]
109+
110+
# Build numeric mapping for AffineFidelityCostModel.
111+
# For numeric fidelity: use values directly.
112+
# For categorical (str) fidelity: map to ordinal indices sorted by cost.
113+
raw_keys = sorted(per_f.keys(), key=lambda k: per_f[k])
114+
is_numeric = all(isinstance(k, (int, float)) for k in raw_keys)
115+
if is_numeric:
116+
fid_vals = sorted(float(k) for k in raw_keys)
117+
costs = [per_f[k] for k in sorted(per_f.keys())]
118+
else:
119+
# Map categorical levels to 0..N-1 ordered by ascending cost
120+
fid_vals = list(range(len(raw_keys)))
121+
costs = [per_f[k] for k in raw_keys]
122+
112123
cost_intercept = min(costs) # cheapest fidelity floor
113124
# Weights: slope per fidelity unit. For discrete {0,1}: w = cost_high - cost_low.
114125
# For continuous: linear fit.
115-
fid_range = fid_vals[-1] - fid_vals[0]
126+
fid_range = fid_vals[-1] - fid_vals[0] if len(fid_vals) > 1 else 0
116127
if fid_range > 0:
117128
fid_weight = (max(costs) - min(costs)) / fid_range
118129
else:
@@ -121,11 +132,12 @@ def _update_cost_state(client, mf_cost, log=None):
121132
opts = mf_cost["acqf_opts_ref"]
122133
opts["cost_intercept"] = max(cost_intercept, 1e-6)
123134
opts["fidelity_weights"] = {
124-
int(k): fid_weight for k in mf_cost["state"]["per_fidelity"]
135+
int(fv): fid_weight for fv in fid_vals
125136
}
126137
if log:
127138
log.info("Multi-fidelity cost updated: intercept=%.2f, weight=%.2f "
128-
"(from %d fidelity levels)", cost_intercept, fid_weight, len(per_f))
139+
"(from %d fidelity levels: %s)", cost_intercept, fid_weight,
140+
len(per_f), list(per_f.keys()))
129141

130142

131143
def optimize(cfg, debug=False):

0 commit comments

Comments
 (0)