Skip to content

Commit 4a48db3

Browse files
committed
Use array_to_params to translate categoricals back for max()
1 parent 25084fe commit 4a48db3

3 files changed

Lines changed: 22 additions & 2 deletions

File tree

bayes_opt/parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def to_param(self, value: float | NDArray[Float]) -> float:
195195
Any
196196
The canonical representation of the parameter.
197197
"""
198-
return value.flatten()[0]
198+
return value.flatten()[0].item()
199199

200200
def to_string(self, value: float, str_len: int) -> str:
201201
"""Represent a parameter value as a string.

bayes_opt/target_space.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def max(self) -> dict[str, Any] | None:
643643
params = self.params[self.mask]
644644
target_max_idx = np.argmax(target)
645645

646-
res = {"target": target_max, "params": dict(zip(self.keys, params[target_max_idx]))}
646+
res = {"target": target_max, "params": self.array_to_params(params[target_max_idx])}
647647

648648
if self._constraint is not None:
649649
constraint_values = self.constraint_values[self.mask]

tests/test_target_space.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,26 @@ def test_max_with_constraint_identical_target_value():
253253
assert space.max() == {"params": {"p1": 2, "p2": 3}, "target": 5, "constraint": -1}
254254

255255

256+
def test_max_categorical() -> None:
257+
PBOUNDS = {
258+
"first_float": (0.0, 1.0),
259+
"categorical_value": ("a", "b", "c", "d"),
260+
"second_float": (0.0, 1.0),
261+
}
262+
263+
def _f(first_float: float, categorical_value: str, second_float: float) -> float:
264+
return second_float if categorical_value == "c" else first_float
265+
266+
space = TargetSpace(_f, PBOUNDS)
267+
space.probe(params={"first_float": 0.1, "categorical_value": "a", "second_float": 0.1})
268+
space.probe(params={"first_float": 0.1, "categorical_value": "b", "second_float": 0.9})
269+
space.probe(params={"first_float": 0.1, "categorical_value": "c", "second_float": 0.8})
270+
space.probe(params={"first_float": 0.1, "categorical_value": "d", "second_float": 0.9})
271+
272+
expected = {"first_float": 0.1, "categorical_value": "c", "second_float": 0.8}
273+
assert space.max()["params"] == expected
274+
275+
256276
def test_res():
257277
PBOUNDS = {"p1": (0, 10), "p2": (1, 100)}
258278
space = TargetSpace(target_func, PBOUNDS)

0 commit comments

Comments
 (0)