Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion openml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from .__version__ import __version__
from .datasets import OpenMLDataFeature, OpenMLDataset
from .evaluations import OpenMLEvaluation
from .evaluations import OpenMLEvaluation, list_estimation_procedures
from .flows import OpenMLFlow
from .runs import OpenMLRun
from .setups import OpenMLParameter, OpenMLSetup
Expand Down Expand Up @@ -122,6 +122,7 @@ def populate_cache(
"exceptions",
"extensions",
"flows",
"list_estimation_procedures",
"runs",
"setups",
"study",
Expand Down
8 changes: 7 additions & 1 deletion openml/evaluations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# License: BSD 3-Clause

from .evaluation import OpenMLEvaluation
from .functions import list_evaluation_measures, list_evaluations, list_evaluations_setups
from .functions import (
list_estimation_procedures,
list_evaluation_measures,
list_evaluations,
list_evaluations_setups,
)

__all__ = [
"OpenMLEvaluation",
"list_estimation_procedures",
"list_evaluation_measures",
"list_evaluations",
"list_evaluations_setups",
Expand Down
14 changes: 7 additions & 7 deletions openml/evaluations/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,15 @@ def list_evaluation_measures() -> list[str]:
return qualities["oml:evaluation_measures"]["oml:measures"][0]["oml:measure"]


def list_estimation_procedures() -> list[str]:
"""Return list of evaluation procedures available.
def list_estimation_procedures() -> dict[int, str]:
"""Return dictionary of evaluation procedures available.

The function performs an API call to retrieve the entire list of
evaluation procedures' names that are available.
evaluation procedures' ids and names that are available.

Returns
-------
list
dict[int, str]
"""
api_call = "estimationprocedure/list"
xml_string = openml._api_calls._perform_api_call(api_call, "get")
Expand All @@ -322,10 +322,10 @@ def list_estimation_procedures() -> list[str]:
if not isinstance(api_results["oml:estimationprocedures"]["oml:estimationprocedure"], list):
raise TypeError('Error in return XML, does not contain "oml:estimationprocedure" as a list')

return [
prod["oml:name"]
return {
int(prod["oml:id"]): prod["oml:name"]
for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"]
]
}


def list_evaluations_setups(
Expand Down
23 changes: 23 additions & 0 deletions tests/test_evaluations/test_evaluation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,26 @@ def test_list_evaluations_setups_filter_task(self):
task_id = [6]
size = 121
self._check_list_evaluation_setups(tasks=task_id, size=size)

@pytest.mark.test_server()
def test_list_estimation_procedures_return_type(self):
procedures = openml.evaluations.list_estimation_procedures()
assert isinstance(procedures, dict)
assert len(procedures) > 0
assert all(isinstance(k, int) for k in procedures.keys())
assert all(isinstance(v, str) for v in procedures.values())

@pytest.mark.test_server()
def test_list_estimation_procedures_top_level_accessible(self):
procedures = openml.list_estimation_procedures()
assert isinstance(procedures, dict)
assert len(procedures) > 0
assert all(isinstance(k, int) for k in procedures.keys())
assert all(isinstance(v, str) for v in procedures.values())

@pytest.mark.test_server()
def test_list_estimation_procedures_valid_id_for_task_creation(self):
procedures = openml.evaluations.list_estimation_procedures()
first_id = list(procedures.keys())[0]
assert isinstance(first_id, int)
assert first_id > 0
Loading