Skip to content

Commit 5f31a73

Browse files
Fix TestClusterExecutor get_item_from_future (#910)
* Fix TestClusterExecutor get_item_from_future * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ruff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * some more tests * reset * fix test * fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4a96b1c commit 5f31a73

2 files changed

Lines changed: 80 additions & 12 deletions

File tree

src/executorlib/task_scheduler/file/shared.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class FutureItem:
14-
def __init__(self, file_name: str):
14+
def __init__(self, file_name: str, selector: Optional[int | str] = None):
1515
"""
1616
Initialize a FutureItem object.
1717
@@ -20,6 +20,7 @@ def __init__(self, file_name: str):
2020
2121
"""
2222
self._file_name = file_name
23+
self._selector = selector
2324

2425
def result(self) -> Any:
2526
"""
@@ -31,7 +32,10 @@ def result(self) -> Any:
3132
"""
3233
exec_flag, no_error_flag, result = get_output(file_name=self._file_name)
3334
if exec_flag and no_error_flag:
34-
return result
35+
if self._selector is not None:
36+
return result[self._selector]
37+
else:
38+
return result
3539
elif exec_flag:
3640
raise result
3741
else:
@@ -239,29 +243,45 @@ def _convert_args_and_kwargs(
239243
task_kwargs = {}
240244
future_wait_key_lst = []
241245
for arg in task_dict["args"]:
246+
selector = None
242247
if isinstance(arg, Future):
248+
if hasattr(arg, "_future") and hasattr(arg, "_selector"):
249+
selector = arg._selector
250+
future = arg._future
251+
else:
252+
future = arg
243253
match_found = False
244254
for k, v in memory_dict.items():
245-
if arg == v:
246-
task_args.append(FutureItem(file_name=file_name_dict[k]))
255+
if future == v:
256+
task_args.append(
257+
FutureItem(file_name=file_name_dict[k], selector=selector)
258+
)
247259
future_wait_key_lst.append(k)
248260
match_found = True
249261
break
250262
if not match_found:
251-
task_args.append(arg.result())
263+
task_args.append(future.result())
252264
else:
253265
task_args.append(arg)
254266
for key, arg in task_dict["kwargs"].items():
267+
selector = None
255268
if isinstance(arg, Future):
269+
if hasattr(arg, "_future") and hasattr(arg, "_selector"):
270+
selector = arg._selector
271+
future = arg._future
272+
else:
273+
future = arg
256274
match_found = False
257275
for k, v in memory_dict.items():
258-
if arg == v:
259-
task_kwargs[key] = FutureItem(file_name=file_name_dict[k])
276+
if future == v:
277+
task_kwargs[key] = FutureItem(
278+
file_name=file_name_dict[k], selector=selector
279+
)
260280
future_wait_key_lst.append(k)
261281
match_found = True
262282
break
263283
if not match_found:
264-
task_kwargs[key] = arg.result()
284+
task_kwargs[key] = future.result()
265285
else:
266286
task_kwargs[key] = arg
267287
return task_args, task_kwargs, future_wait_key_lst

tests/unit/task_scheduler/file/test_backend.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import shutil
44
import unittest
55

6+
from executorlib.standalone.select import FutureSelector
7+
68

79
try:
810
from executorlib.task_scheduler.file.backend import backend_execute_task_in_file
9-
from executorlib.task_scheduler.file.shared import _check_task_output, FutureItem
11+
from executorlib.task_scheduler.file.shared import _check_task_output, _convert_args_and_kwargs, FutureItem
1012
from executorlib.standalone.hdf import dump, get_runtime
1113
from executorlib.standalone.serialize import serialize_funct
1214

@@ -19,6 +21,14 @@ def my_funct(a, b):
1921
return a + b
2022

2123

24+
def return_dict(a, b):
25+
return {"a": a, "b": b}
26+
27+
28+
def return_list(a, b):
29+
return [a, b]
30+
31+
2232
def get_error(a):
2333
raise ValueError(a)
2434

@@ -36,7 +46,6 @@ def test_execute_function_mixed(self):
3646
fn_kwargs={"b": 2},
3747
)
3848
file_name = os.path.join(cache_directory, task_key + "_i.h5")
39-
os.makedirs(cache_directory, exist_ok=True)
4049
dump(file_name=file_name, data_dict=data_dict)
4150
backend_execute_task_in_file(file_name=file_name)
4251
future_obj = Future()
@@ -55,6 +64,47 @@ def test_execute_function_mixed(self):
5564
self.assertTrue(future_file_obj.done())
5665
self.assertEqual(future_file_obj.result(), 3)
5766

67+
def test_execute_function_mixed_selector_convert(self):
68+
cache_directory = os.path.abspath("executorlib_cache")
69+
os.makedirs(cache_directory, exist_ok=True)
70+
task_key_1, data_dict = serialize_funct(
71+
fn=return_dict,
72+
fn_args=[1],
73+
fn_kwargs={"b": 2},
74+
)
75+
file_name_1 = os.path.join(cache_directory, task_key_1 + "_i.h5")
76+
dump(file_name=file_name_1, data_dict=data_dict)
77+
backend_execute_task_in_file(file_name=file_name_1)
78+
f1 = Future()
79+
_check_task_output(
80+
task_key=task_key_1, future_obj=f1, cache_directory=cache_directory
81+
)
82+
task_key_2, data_dict = serialize_funct(
83+
fn=return_list,
84+
fn_args=[1],
85+
fn_kwargs={"b": 2},
86+
)
87+
file_name_2 = os.path.join(cache_directory, task_key_2 + "_i.h5")
88+
dump(file_name=file_name_2, data_dict=data_dict)
89+
backend_execute_task_in_file(file_name=file_name_2)
90+
f2 = Future()
91+
_check_task_output(
92+
task_key=task_key_2, future_obj=f2, cache_directory=cache_directory
93+
)
94+
fs1 = FutureSelector(future=f1, selector="a")
95+
fs2 = FutureSelector(future=f2, selector=1)
96+
task_args, task_kwargs, future_wait_key_lst = _convert_args_and_kwargs(
97+
task_dict={"fn": 1, "args": (fs1,), "kwargs": {"b": fs2}},
98+
memory_dict={task_key_1: f1, task_key_2: f2},
99+
file_name_dict={
100+
task_key_1: os.path.join(cache_directory, task_key_1 + "_o.h5"),
101+
task_key_2: os.path.join(cache_directory, task_key_2 + "_o.h5"),
102+
},
103+
)
104+
self.assertEqual(task_args[0].result(), 1)
105+
self.assertEqual(task_kwargs["b"].result(), 2)
106+
self.assertTrue(len(future_wait_key_lst) == 2)
107+
58108
def test_execute_function_args(self):
59109
cache_directory = os.path.abspath("executorlib_cache")
60110
os.makedirs(cache_directory, exist_ok=True)
@@ -92,7 +142,6 @@ def test_execute_function_kwargs(self):
92142
fn_kwargs={"a": 1, "b": 2},
93143
)
94144
file_name = os.path.join(cache_directory, task_key + "_i.h5")
95-
os.makedirs(cache_directory, exist_ok=True)
96145
dump(file_name=file_name, data_dict=data_dict)
97146
backend_execute_task_in_file(file_name=file_name)
98147
future_obj = Future()
@@ -120,7 +169,6 @@ def test_execute_function_error(self):
120169
fn_kwargs={"a": 1},
121170
)
122171
file_name = os.path.join(cache_directory, task_key + "_i.h5")
123-
os.makedirs(cache_directory, exist_ok=True)
124172
data_dict["error_log_file"] = os.path.join(cache_directory, "error.out")
125173
dump(file_name=file_name, data_dict=data_dict)
126174
backend_execute_task_in_file(file_name=file_name)

0 commit comments

Comments
 (0)